diff options
-rw-r--r-- | tensorflow/core/platform/cloud/BUILD | 14 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/gcs_dns_cache.cc | 31 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/gcs_file_system.cc | 21 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/google_auth_provider.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/oauth_client.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/time_util.cc | 3 | ||||
-rw-r--r-- | tensorflow/core/platform/default/build_config.bzl | 1 | ||||
-rw-r--r-- | tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh | 2 | ||||
-rw-r--r-- | third_party/curl.BUILD | 26 |
9 files changed, 91 insertions, 16 deletions
diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD index 624145da75..aaeccc8324 100644 --- a/tensorflow/core/platform/cloud/BUILD +++ b/tensorflow/core/platform/cloud/BUILD @@ -10,6 +10,7 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", "tf_cc_test", + "tf_copts", ) filegroup( @@ -29,6 +30,7 @@ filegroup( cc_library( name = "expiring_lru_cache", hdrs = ["expiring_lru_cache.h"], + copts = tf_copts(), visibility = ["//tensorflow:__subpackages__"], deps = ["//tensorflow/core:lib"], ) @@ -37,6 +39,7 @@ cc_library( name = "file_block_cache", srcs = ["file_block_cache.cc"], hdrs = ["file_block_cache.h"], + copts = tf_copts(), visibility = ["//tensorflow:__subpackages__"], deps = ["//tensorflow/core:lib"], ) @@ -45,6 +48,7 @@ cc_library( name = "gcs_dns_cache", srcs = ["gcs_dns_cache.cc"], hdrs = ["gcs_dns_cache.h"], + copts = tf_copts(), visibility = ["//tensorflow:__subpackages__"], deps = [ ":http_request", @@ -56,6 +60,7 @@ cc_library( name = "gcs_file_system", srcs = ["gcs_file_system.cc"], hdrs = ["gcs_file_system.h"], + copts = tf_copts(), linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 visibility = ["//visibility:public"], deps = [ @@ -78,6 +83,7 @@ cc_library( cc_library( name = "http_request", hdrs = ["http_request.h"], + copts = tf_copts(), visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/core:framework_headers_lib", @@ -89,6 +95,7 @@ cc_library( name = "curl_http_request", srcs = ["curl_http_request.cc"], hdrs = ["curl_http_request.h"], + copts = tf_copts(), visibility = ["//tensorflow:__subpackages__"], deps = [ ":http_request", @@ -104,6 +111,7 @@ cc_library( hdrs = [ "http_request_fake.h", ], + copts = tf_copts(), visibility = ["//tensorflow:__subpackages__"], deps = [ ":curl_http_request", @@ -121,6 +129,7 @@ cc_library( "auth_provider.h", "google_auth_provider.h", ], + copts = tf_copts(), visibility = ["//tensorflow:__subpackages__"], deps = [ ":curl_http_request", @@ -136,6 +145,7 @@ cc_library( name = "now_seconds_env", testonly = 1, hdrs = ["now_seconds_env.h"], + copts = tf_copts(), visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/core:lib", @@ -151,6 +161,7 @@ cc_library( hdrs = [ "oauth_client.h", ], + copts = tf_copts(), deps = [ ":curl_http_request", ":http_request", @@ -169,6 +180,7 @@ cc_library( hdrs = [ "retrying_utils.h", ], + copts = tf_copts(), deps = [ "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib_internal", @@ -183,6 +195,7 @@ cc_library( hdrs = [ "retrying_file_system.h", ], + copts = tf_copts(), deps = [ ":retrying_utils", "//tensorflow/core:framework_headers_lib", @@ -198,6 +211,7 @@ cc_library( hdrs = [ "time_util.h", ], + copts = tf_copts(), deps = [ "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib_internal", diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache.cc b/tensorflow/core/platform/cloud/gcs_dns_cache.cc index 63f2da065d..840f2b21cd 100644 --- a/tensorflow/core/platform/cloud/gcs_dns_cache.cc +++ b/tensorflow/core/platform/cloud/gcs_dns_cache.cc @@ -14,9 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/cloud/gcs_dns_cache.h" - +#ifndef _WIN32 #include <arpa/inet.h> #include <netdb.h> +#else +#include <winsock2.h> +#include <ws2tcpip.h> +#include <Windows.h> +#endif #include <sys/types.h> namespace tensorflow { @@ -26,6 +31,20 @@ namespace { constexpr char kStorageHost[] = "storage.googleapis.com"; constexpr char kWwwHost[] = "www.googleapis.com"; +inline void print_getaddrinfo_error(const string& name, int error_code) { +#ifndef _WIN32 + if (error_code == EAI_SYSTEM) { + LOG(ERROR) << "Error resolving " << name + << " (EAI_SYSTEM): " << strerror(errno); + } else { + LOG(ERROR) << "Error resolving " << name << ": " + << gai_strerror(error_code); + } +#else + // TODO:WSAGetLastError is better than gai_strerror + LOG(ERROR) << "Error resolving " << name << ": " << gai_strerror(error_code); +#endif +} } // namespace GcsDnsCache::GcsDnsCache(Env* env, int64 refresh_rate_secs) @@ -77,7 +96,7 @@ Status GcsDnsCache::AnnotateRequest(HttpRequest* request) { std::vector<string> output; if (return_code == 0) { - for (addrinfo* i = result; i != nullptr; i = i->ai_next) { + for (const addrinfo* i = result; i != nullptr; i = i->ai_next) { if (i->ai_family != AF_INET || i->ai_addr->sa_family != AF_INET) { LOG(WARNING) << "Non-IPv4 address returned. ai_family: " << i->ai_family << ". sa_family: " << i->ai_addr->sa_family << "."; @@ -96,13 +115,7 @@ Status GcsDnsCache::AnnotateRequest(HttpRequest* request) { } } } else { - if (return_code == EAI_SYSTEM) { - LOG(ERROR) << "Error resolving " << name - << " (EAI_SYSTEM): " << strerror(errno); - } else { - LOG(ERROR) << "Error resolving " << name << ": " - << gai_strerror(return_code); - } + print_getaddrinfo_error(name, return_code); } if (result != nullptr) { freeaddrinfo(result); diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 45e9b05092..c44cad9fc8 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -22,6 +22,9 @@ limitations under the License. #include <cstring> #include <fstream> #include <vector> +#ifdef _WIN32 +#include <io.h> //for _mktemp +#endif #include "include/json/json.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -40,6 +43,12 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/thread_annotations.h" +#ifdef _WIN32 +#ifdef DeleteFile +#undef DeleteFile +#endif +#endif + namespace tensorflow { namespace { @@ -95,16 +104,25 @@ const FileStatistics DIRECTORY_STAT(0, 0, true); // userspace DNS cache. constexpr char kResolveCacheSecs[] = "GCS_RESOLVE_REFRESH_SECS"; +// TODO: DO NOT use a hardcoded path Status GetTmpFilename(string* filename) { if (!filename) { return errors::Internal("'filename' cannot be nullptr."); } +#ifndef _WIN32 char buffer[] = "/tmp/gcs_filesystem_XXXXXX"; int fd = mkstemp(buffer); if (fd < 0) { return errors::Internal("Failed to create a temporary file."); } close(fd); +#else + char buffer[] = "/tmp/gcs_filesystem_XXXXXX"; + char* ret = _mktemp(buffer); + if (ret == nullptr) { + return errors::Internal("Failed to create a temporary file."); + } +#endif *filename = buffer; return Status::OK(); } @@ -292,6 +310,7 @@ class GcsWritableFile : public WritableFile { file_cache_erase_(std::move(file_cache_erase)), sync_needed_(true), initial_retry_delay_usec_(initial_retry_delay_usec) { + // TODO: to make it safer, outfile_ should be constructed from an FD if (GetTmpFilename(&tmp_content_filename_).ok()) { outfile_.open(tmp_content_filename_, std::ofstream::binary | std::ofstream::app); @@ -416,7 +435,7 @@ class GcsWritableFile : public WritableFile { return errors::Internal("'size' cannot be nullptr"); } const auto tellp = outfile_.tellp(); - if (tellp == -1) { + if (tellp == static_cast<std::streampos>(-1)) { return errors::Internal( "Could not get the size of the internal temporary file."); } diff --git a/tensorflow/core/platform/cloud/google_auth_provider.cc b/tensorflow/core/platform/cloud/google_auth_provider.cc index f6fd8373cd..d77f439c5a 100644 --- a/tensorflow/core/platform/cloud/google_auth_provider.cc +++ b/tensorflow/core/platform/cloud/google_auth_provider.cc @@ -14,9 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/cloud/google_auth_provider.h" +#ifndef _WIN32 #include <pwd.h> -#include <sys/types.h> #include <unistd.h> +#else +#include <sys/types.h> +#endif #include <fstream> #include "include/json/json.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc index c700b97dc9..3c2830ccd9 100644 --- a/tensorflow/core/platform/cloud/oauth_client.cc +++ b/tensorflow/core/platform/cloud/oauth_client.cc @@ -14,9 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/cloud/oauth_client.h" +#ifndef _WIN32 #include <pwd.h> #include <sys/types.h> #include <unistd.h> +#else +#include <sys/types.h> +#endif #include <fstream> #include <openssl/bio.h> #include <openssl/evp.h> diff --git a/tensorflow/core/platform/cloud/time_util.cc b/tensorflow/core/platform/cloud/time_util.cc index 2f8643f3c7..0587a65c29 100644 --- a/tensorflow/core/platform/cloud/time_util.cc +++ b/tensorflow/core/platform/cloud/time_util.cc @@ -18,6 +18,9 @@ limitations under the License. #include <cmath> #include <cstdio> #include <ctime> +#ifdef _WIN32 +#define timegm _mkgmtime +#endif #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 0f8cf8f122..948334d27b 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -458,7 +458,6 @@ def tf_additional_lib_deps(): def tf_additional_core_deps(): return select({ - "//tensorflow:with_gcp_support_windows_override": [], "//tensorflow:with_gcp_support_android_override": [], "//tensorflow:with_gcp_support_ios_override": [], "//tensorflow:with_gcp_support": [ diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh index 1e455ddc99..8d50250c3a 100644 --- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh +++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh @@ -111,7 +111,7 @@ function run_configure_for_cpu_build { export TF_NEED_MKL=0 fi export TF_NEED_VERBS=0 - export TF_NEED_GCP=0 + export TF_NEED_GCP=1 export TF_NEED_HDFS=0 export TF_NEED_OPENCL_SYCL=0 echo "" | ./configure diff --git a/third_party/curl.BUILD b/third_party/curl.BUILD index e311c7e758..4def6f9489 100644 --- a/third_party/curl.BUILD +++ b/third_party/curl.BUILD @@ -10,6 +10,7 @@ CURL_WIN_COPTS = [ "/DHAVE_CONFIG_H", "/DCURL_DISABLE_FTP", "/DCURL_DISABLE_NTLM", + "/DCURL_DISABLE_PROXY", "/DHAVE_LIBZ", "/DHAVE_ZLIB_H", # Defining _USING_V110_SDK71_ is hackery to defeat curl's incorrect @@ -23,6 +24,8 @@ CURL_WIN_SRCS = [ "lib/asyn-thread.c", "lib/inet_ntop.c", "lib/system_win32.c", + "lib/vtls/schannel.c", + "lib/idn_win32.c", ] cc_library( @@ -276,6 +279,7 @@ cc_library( "-DCURL_MAX_WRITE_SIZE=65536", ], }), + defines = ["CURL_STATICLIB"], includes = ["include"], linkopts = select({ "@org_tensorflow//tensorflow:android": [ @@ -289,10 +293,16 @@ cc_library( ], "@org_tensorflow//tensorflow:ios": [], "@org_tensorflow//tensorflow:windows": [ - "-Wl,ws2_32.lib", + "-DEFAULTLIB:ws2_32.lib", + "-DEFAULTLIB:advapi32.lib", + "-DEFAULTLIB:crypt32.lib", + "-DEFAULTLIB:Normaliz.lib", ], "@org_tensorflow//tensorflow:windows_msvc": [ - "-Wl,ws2_32.lib", + "-DEFAULTLIB:ws2_32.lib", + "-DEFAULTLIB:advapi32.lib", + "-DEFAULTLIB:crypt32.lib", + "-DEFAULTLIB:Normaliz.lib", ], "//conditions:default": [ "-lrt", @@ -438,12 +448,22 @@ genrule( "# include \"lib/config-win32.h\"", "# define BUILDING_LIBCURL 1", "# define CURL_DISABLE_CRYPTO_AUTH 1", + "# define CURL_DISABLE_DICT 1", + "# define CURL_DISABLE_FILE 1", + "# define CURL_DISABLE_GOPHER 1", "# define CURL_DISABLE_IMAP 1", "# define CURL_DISABLE_LDAP 1", "# define CURL_DISABLE_LDAPS 1", "# define CURL_DISABLE_POP3 1", "# define CURL_PULL_WS2TCPIP_H 1", - "# define HTTP_ONLY 1", + "# define CURL_DISABLE_SMTP 1", + "# define CURL_DISABLE_TELNET 1", + "# define CURL_DISABLE_TFTP 1", + "# define CURL_PULL_WS2TCPIP_H 1", + "# define USE_WINDOWS_SSPI 1", + "# define USE_WIN32_IDN 1", + "# define USE_SCHANNEL 1", + "# define WANT_IDN_PROTOTYPES 1", "#elif defined(__APPLE__)", "# define HAVE_FSETXATTR_6 1", "# define HAVE_SETMODE 1", |