diff options
author | Anjali Sridhar <anjalisridhar@google.com> | 2018-07-06 13:50:29 -0700 |
---|---|---|
committer | Yifei Feng <yifeif@google.com> | 2018-07-06 15:17:59 -0700 |
commit | 90fc5e3819ed62e93228a9c2c29dede0f0f8cfd6 (patch) | |
tree | 0e50e14646a382fbdf5edec988f9818bb93b12c0 /tensorflow/core/platform | |
parent | d64754c5c768f26b6a95b350cfd8c7ded2590dc9 (diff) |
Allow is_initialized and initializer to be called on MirroredVariables and TowerLocalVariables.
PiperOrigin-RevId: 203520287
Diffstat (limited to 'tensorflow/core/platform')
-rw-r--r-- | tensorflow/core/platform/s3/s3_crypto.cc | 113 | ||||
-rw-r--r-- | tensorflow/core/platform/s3/s3_crypto.h | 35 | ||||
-rw-r--r-- | tensorflow/core/platform/vmodule_benchmark_test.cc | 28 | ||||
-rw-r--r-- | tensorflow/core/platform/vmodule_test.cc | 117 |
4 files changed, 293 insertions, 0 deletions
diff --git a/tensorflow/core/platform/s3/s3_crypto.cc b/tensorflow/core/platform/s3/s3_crypto.cc new file mode 100644 index 0000000000..d7062a59d2 --- /dev/null +++ b/tensorflow/core/platform/s3/s3_crypto.cc @@ -0,0 +1,113 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/platform/s3/s3_crypto.h" +#include <openssl/hmac.h> +#include <openssl/sha.h> + +#include <aws/core/utils/crypto/HashResult.h> +#include <aws/s3/S3Client.h> + +namespace tensorflow { + +class S3Sha256HMACOpenSSLImpl : public Aws::Utils::Crypto::HMAC { + public: + S3Sha256HMACOpenSSLImpl() {} + + virtual ~S3Sha256HMACOpenSSLImpl() = default; + + virtual Aws::Utils::Crypto::HashResult Calculate( + const Aws::Utils::ByteBuffer& toSign, + const Aws::Utils::ByteBuffer& secret) override { + unsigned int length = SHA256_DIGEST_LENGTH; + Aws::Utils::ByteBuffer digest(length); + memset(digest.GetUnderlyingData(), 0, length); + + HMAC_CTX ctx; + HMAC_CTX_init(&ctx); + + HMAC_Init_ex(&ctx, secret.GetUnderlyingData(), + static_cast<int>(secret.GetLength()), EVP_sha256(), NULL); + HMAC_Update(&ctx, toSign.GetUnderlyingData(), toSign.GetLength()); + HMAC_Final(&ctx, digest.GetUnderlyingData(), &length); + HMAC_CTX_cleanup(&ctx); + + return Aws::Utils::Crypto::HashResult(std::move(digest)); + } +}; + +class S3Sha256OpenSSLImpl : public Aws::Utils::Crypto::Hash { + public: + S3Sha256OpenSSLImpl() {} + + virtual ~S3Sha256OpenSSLImpl() = default; + + virtual Aws::Utils::Crypto::HashResult Calculate( + const Aws::String& str) override { + SHA256_CTX sha256; + SHA256_Init(&sha256); + SHA256_Update(&sha256, str.data(), str.size()); + + Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH); + SHA256_Final(hash.GetUnderlyingData(), &sha256); + + return Aws::Utils::Crypto::HashResult(std::move(hash)); + } + + virtual Aws::Utils::Crypto::HashResult Calculate( + Aws::IStream& stream) override { + SHA256_CTX sha256; + SHA256_Init(&sha256); + + auto currentPos = stream.tellg(); + if (currentPos == std::streampos(std::streamoff(-1))) { + currentPos = 0; + stream.clear(); + } + + stream.seekg(0, stream.beg); + + char streamBuffer + [Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE]; + while (stream.good()) { + stream.read(streamBuffer, + Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE); + auto bytesRead = stream.gcount(); + + if (bytesRead > 0) { + SHA256_Update(&sha256, streamBuffer, static_cast<size_t>(bytesRead)); + } + } + + stream.clear(); + stream.seekg(currentPos, stream.beg); + + Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH); + SHA256_Final(hash.GetUnderlyingData(), &sha256); + + return Aws::Utils::Crypto::HashResult(std::move(hash)); + } +}; + +std::shared_ptr<Aws::Utils::Crypto::Hash> +S3SHA256Factory::CreateImplementation() const { + return Aws::MakeShared<S3Sha256OpenSSLImpl>(S3CryptoAllocationTag); +} + +std::shared_ptr<Aws::Utils::Crypto::HMAC> +S3SHA256HmacFactory::CreateImplementation() const { + return Aws::MakeShared<S3Sha256HMACOpenSSLImpl>(S3CryptoAllocationTag); +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/s3/s3_crypto.h b/tensorflow/core/platform/s3/s3_crypto.h new file mode 100644 index 0000000000..e376b8b0c0 --- /dev/null +++ b/tensorflow/core/platform/s3/s3_crypto.h @@ -0,0 +1,35 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <aws/core/Aws.h> +#include <aws/core/utils/crypto/Factories.h> +#include <aws/core/utils/crypto/HMAC.h> +#include <aws/core/utils/crypto/Hash.h> + +namespace tensorflow { +static const char* S3CryptoAllocationTag = "S3CryptoAllocation"; + +class S3SHA256Factory : public Aws::Utils::Crypto::HashFactory { + public: + std::shared_ptr<Aws::Utils::Crypto::Hash> CreateImplementation() + const override; +}; + +class S3SHA256HmacFactory : public Aws::Utils::Crypto::HMACFactory { + public: + std::shared_ptr<Aws::Utils::Crypto::HMAC> CreateImplementation() + const override; +}; + +} // namespace tensorflow diff --git a/tensorflow/core/platform/vmodule_benchmark_test.cc b/tensorflow/core/platform/vmodule_benchmark_test.cc new file mode 100644 index 0000000000..0f9e75bf9c --- /dev/null +++ b/tensorflow/core/platform/vmodule_benchmark_test.cc @@ -0,0 +1,28 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { + +static void BM_DisabledVlog(int iters) { + for (int i = 0; i < iters; ++i) { + VLOG(1) << "Testing VLOG(1)!"; + } +} +BENCHMARK(BM_DisabledVlog); + +} // namespace tensorflow diff --git a/tensorflow/core/platform/vmodule_test.cc b/tensorflow/core/platform/vmodule_test.cc new file mode 100644 index 0000000000..47b4b2e0e7 --- /dev/null +++ b/tensorflow/core/platform/vmodule_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Test that popens a child process with the VLOG-ing environment variable set +// for the logging framework, and observes VLOG_IS_ON and VLOG macro output. + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/test.h" + +#include <string.h> + +namespace tensorflow { +namespace { + +int RealMain(const char* argv0, bool do_vlog) { + if (do_vlog) { +#if !defined(PLATFORM_GOOGLE) + // Note, we only test this when !defined(PLATFORM_GOOGLE) because + // VmoduleActivated doesn't exist in that implementation. + // + // Also, we call this internal API to simulate what would happen if + // differently-named translation units attempted to VLOG, so we don't need + // to create dummy translation unit files. + bool ok = internal::LogMessage::VmoduleActivated("vmodule_test.cc", 7) && + internal::LogMessage::VmoduleActivated("shoobadooba.h", 3); + if (!ok) { + fprintf(stderr, "vmodule activated levels not as expected.\n"); + return EXIT_FAILURE; + } +#endif + + // Print info on which VLOG levels are activated. + fprintf(stderr, "VLOG_IS_ON(8)? %d\n", VLOG_IS_ON(8)); + fprintf(stderr, "VLOG_IS_ON(7)? %d\n", VLOG_IS_ON(7)); + fprintf(stderr, "VLOG_IS_ON(6)? %d\n", VLOG_IS_ON(6)); + // Do some VLOG-ing. + VLOG(8) << "VLOG(8)"; + VLOG(7) << "VLOG(7)"; + VLOG(6) << "VLOG(6)"; + LOG(INFO) << "INFO"; + return EXIT_SUCCESS; + } + + // Popen the child process. + std::string command = std::string(argv0); +#if defined(PLATFORM_GOOGLE) + command = command + " do_vlog --vmodule=vmodule_test=7 --alsologtostderr"; +#else + command = + "TF_CPP_VMODULE=vmodule_test=7,shoobadooba=3 " + command + " do_vlog"; +#endif + command += " 2>&1"; + fprintf(stderr, "Running: \"%s\"\n", command.c_str()); + FILE* f = popen(command.c_str(), "r"); + if (f == nullptr) { + fprintf(stderr, "Failed to popen child: %s\n", strerror(errno)); + return EXIT_FAILURE; + } + + // Read data from the child's stdout. + constexpr int kBufferSizeBytes = 4096; + char buffer[kBufferSizeBytes]; + size_t result = fread(buffer, sizeof(buffer[0]), kBufferSizeBytes - 1, f); + if (result == 0) { + fprintf(stderr, "Failed to read from child stdout: %zu %s\n", result, + strerror(errno)); + return EXIT_FAILURE; + } + buffer[result] = '\0'; + int status = pclose(f); + if (status == -1) { + fprintf(stderr, "Failed to close popen child: %s\n", strerror(errno)); + return EXIT_FAILURE; + } + + // Check output is as expected. + const char kExpected[] = + "VLOG_IS_ON(8)? 0\nVLOG_IS_ON(7)? 1\nVLOG_IS_ON(6)? 1\n"; + if (strstr(buffer, kExpected) == nullptr) { + fprintf(stderr, "error: unexpected output from child: \"%.*s\"\n", + kBufferSizeBytes, buffer); + return EXIT_FAILURE; + } + bool ok = strstr(buffer, "VLOG(7)\n") != nullptr && + strstr(buffer, "VLOG(6)\n") != nullptr && + strstr(buffer, "VLOG(8)\n") == nullptr; + if (!ok) { + fprintf(stderr, "error: VLOG output not as expected: \"%.*s\"\n", + kBufferSizeBytes, buffer); + return EXIT_FAILURE; + } + + // Success! + return EXIT_SUCCESS; +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + bool do_vlog = argc >= 2 && strcmp(argv[1], "do_vlog") == 0; + return tensorflow::RealMain(argv[0], do_vlog); +} |