Make SSLStream initialization non-blocking Existing initialization solution was blocking. BUG:24961367 Change-Id: Iec5125013161cfd0c50ede78efdbdb1cb42ebd64 Reviewed-on: https://weave-review.googlesource.com/1362 Reviewed-by: Alex Vakulenko <avakulenko@google.com>
diff --git a/libweave/examples/provider/event_network.cc b/libweave/examples/provider/event_network.cc index 82165be..d3a3391 100644 --- a/libweave/examples/provider/event_network.cc +++ b/libweave/examples/provider/event_network.cc
@@ -100,20 +100,7 @@ void EventNetworkImpl::OpenSslSocket(const std::string& host, uint16_t port, const OpenSslSocketCallback& callback) { - // Connect to SSL port instead of upgrading to TLS. - std::unique_ptr<SSLStream> tls_stream{new SSLStream{task_runner_}}; - - if (tls_stream->Init(host, port)) { - task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(callback, base::Passed(&tls_stream), nullptr), - {}); - } else { - ErrorPtr error; - Error::AddTo(&error, FROM_HERE, "tls", "tls_init_failed", - "Failed to initialize TLS stream."); - task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {}); - } + SSLStream::Connect(task_runner_, host, port, callback); } } // namespace examples
diff --git a/libweave/examples/provider/ssl_stream.cc b/libweave/examples/provider/ssl_stream.cc index d3cbcfd..eea28c0 100644 --- a/libweave/examples/provider/ssl_stream.cc +++ b/libweave/examples/provider/ssl_stream.cc
@@ -4,29 +4,66 @@ #include "examples/provider/ssl_stream.h" +#include <openssl/err.h> + #include <base/bind.h> +#include <base/bind_helpers.h> #include <weave/provider/task_runner.h> namespace weave { namespace examples { namespace { -int GetSSLError(const SSL* ssl, int ret) { + +void AddSslError(ErrorPtr* error, + const tracked_objects::Location& location, + const std::string& error_code, + unsigned long ssl_error_code) { + ERR_load_BIO_strings(); SSL_load_error_strings(); - return SSL_get_error(ssl, ret); + Error::AddToPrintf(error, location, "ssl_stream", error_code, "%s: %s", + ERR_lib_error_string(ssl_error_code), + ERR_reason_error_string(ssl_error_code)); } + +void RetryAsyncTask(provider::TaskRunner* task_runner, + const tracked_objects::Location& location, + const base::Closure& task) { + task_runner->PostDelayedTask(FROM_HERE, task, + base::TimeDelta::FromMilliseconds(100)); +} + } // namespace -SSLStream::SSLStream(provider::TaskRunner* task_runner) +void SSLStream::SslDeleter::operator()(BIO* bio) const { + BIO_free(bio); +} + +void SSLStream::SslDeleter::operator()(SSL* ssl) const { + SSL_free(ssl); +} + +void SSLStream::SslDeleter::operator()(SSL_CTX* ctx) const { + SSL_CTX_free(ctx); +} + +SSLStream::SSLStream(provider::TaskRunner* task_runner, + std::unique_ptr<BIO, SslDeleter> stream_bio) : task_runner_{task_runner} { - SSL_library_init(); + ctx_.reset(SSL_CTX_new(TLSv1_2_client_method())); + CHECK(ctx_); + ssl_.reset(SSL_new(ctx_.get())); + + SSL_set_bio(ssl_.get(), stream_bio.get(), stream_bio.get()); + stream_bio.release(); // Owned by ssl now. + SSL_set_connect_state(ssl_.get()); } SSLStream::~SSLStream() { CancelPendingOperations(); } -void SSLStream::RunDelayedTask(const base::Closure& task) { +void SSLStream::RunTask(const base::Closure& task) { task.Run(); } @@ -37,31 +74,28 @@ if (res > 0) { task_runner_->PostDelayedTask( FROM_HERE, - base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(), + base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), base::Bind(callback, res, nullptr)), {}); return; } - int err = GetSSLError(ssl_.get(), res); + int err = SSL_get_error(ssl_.get(), res); if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { - task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(), - buffer, size_to_read, callback), - base::TimeDelta::FromSeconds(1)); - return; + return RetryAsyncTask( + task_runner_, FROM_HERE, + base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(), buffer, + size_to_read, callback)); } ErrorPtr weave_error; - Error::AddTo(&weave_error, FROM_HERE, "ssl", "socket_read_failed", - "SSL error"); - task_runner_->PostDelayedTask( + AddSslError(&weave_error, FROM_HERE, "read_failed", err); + return task_runner_->PostDelayedTask( FROM_HERE, - base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(), + base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), base::Bind(callback, 0, base::Passed(&weave_error))), {}); - return; } void SSLStream::Write(const void* buffer, @@ -72,81 +106,101 @@ buffer = static_cast<const char*>(buffer) + res; size_to_write -= res; if (size_to_write == 0) { - task_runner_->PostDelayedTask( + return task_runner_->PostDelayedTask( FROM_HERE, - base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(), + base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), base::Bind(callback, nullptr)), {}); - return; } - task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), - buffer, size_to_write, callback), - base::TimeDelta::FromSeconds(1)); - - return; + return RetryAsyncTask( + task_runner_, FROM_HERE, + base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer, + size_to_write, callback)); } - int err = GetSSLError(ssl_.get(), res); + int err = SSL_get_error(ssl_.get(), res); if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { - task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), - buffer, size_to_write, callback), - base::TimeDelta::FromSeconds(1)); - return; + return RetryAsyncTask( + task_runner_, FROM_HERE, + base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer, + size_to_write, callback)); } ErrorPtr weave_error; - Error::AddTo(&weave_error, FROM_HERE, "ssl", "socket_write_failed", - "SSL error"); + AddSslError(&weave_error, FROM_HERE, "write_failed", err); task_runner_->PostDelayedTask( - FROM_HERE, - base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(), - base::Bind(callback, base::Passed(&weave_error))), + FROM_HERE, base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), + base::Bind(callback, base::Passed(&weave_error))), {}); - return; } void SSLStream::CancelPendingOperations() { weak_ptr_factory_.InvalidateWeakPtrs(); } -bool SSLStream::Init(const std::string& host, uint16_t port) { - ctx_.reset(SSL_CTX_new(TLSv1_2_client_method())); - CHECK(ctx_); - ssl_.reset(SSL_new(ctx_.get())); +void SSLStream::Connect( + provider::TaskRunner* task_runner, + const std::string& host, + uint16_t port, + const provider::Network::OpenSslSocketCallback& callback) { + SSL_library_init(); char end_point[255]; snprintf(end_point, sizeof(end_point), "%s:%u", host.c_str(), port); - BIO* stream_bio = BIO_new_connect(end_point); + + std::unique_ptr<BIO, SslDeleter> stream_bio(BIO_new_connect(end_point)); CHECK(stream_bio); - BIO_set_nbio(stream_bio, 1); + BIO_set_nbio(stream_bio.get(), 1); - while (BIO_do_connect(stream_bio) != 1) { - CHECK(BIO_should_retry(stream_bio)); - sleep(1); + std::unique_ptr<SSLStream> stream{ + new SSLStream{task_runner, std::move(stream_bio)}}; + ConnectBio(std::move(stream), callback); +} + +void SSLStream::ConnectBio( + std::unique_ptr<SSLStream> stream, + const provider::Network::OpenSslSocketCallback& callback) { + BIO* bio = SSL_get_rbio(stream->ssl_.get()); + if (BIO_do_connect(bio) == 1) + return DoHandshake(std::move(stream), callback); + + auto task_runner = stream->task_runner_; + if (BIO_should_retry(bio)) { + return RetryAsyncTask( + task_runner, FROM_HERE, + base::Bind(&SSLStream::ConnectBio, base::Passed(&stream), callback)); } - SSL_set_bio(ssl_.get(), stream_bio, stream_bio); - SSL_set_connect_state(ssl_.get()); + ErrorPtr error; + AddSslError(&error, FROM_HERE, "connect_failed", ERR_get_error()); + task_runner->PostDelayedTask( + FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {}); +} - for (;;) { - int res = SSL_do_handshake(ssl_.get()); - if (res) { - return true; - } - - res = GetSSLError(ssl_.get(), res); - - if (res != SSL_ERROR_WANT_READ || res != SSL_ERROR_WANT_WRITE) { - return false; - } - - sleep(1); +void SSLStream::DoHandshake( + std::unique_ptr<SSLStream> stream, + const provider::Network::OpenSslSocketCallback& callback) { + int res = SSL_do_handshake(stream->ssl_.get()); + auto task_runner = stream->task_runner_; + if (res == 1) { + return task_runner->PostDelayedTask( + FROM_HERE, base::Bind(callback, base::Passed(&stream), nullptr), {}); } - return false; + + res = SSL_get_error(stream->ssl_.get(), res); + + if (res == SSL_ERROR_WANT_READ || res == SSL_ERROR_WANT_WRITE) { + return RetryAsyncTask( + task_runner, FROM_HERE, + base::Bind(&SSLStream::DoHandshake, base::Passed(&stream), callback)); + } + + ErrorPtr error; + AddSslError(&error, FROM_HERE, "handshake_failed", res); + task_runner->PostDelayedTask( + FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {}); } } // namespace examples
diff --git a/libweave/examples/provider/ssl_stream.h b/libweave/examples/provider/ssl_stream.h index ca3731b..f1ef4a6 100644 --- a/libweave/examples/provider/ssl_stream.h +++ b/libweave/examples/provider/ssl_stream.h
@@ -8,6 +8,7 @@ #include <openssl/ssl.h> #include <base/memory/weak_ptr.h> +#include <weave/provider/network.h> #include <weave/stream.h> namespace weave { @@ -20,8 +21,6 @@ class SSLStream : public Stream { public: - explicit SSLStream(provider::TaskRunner* task_runner); - ~SSLStream() override; void Read(void* buffer, @@ -34,14 +33,35 @@ void CancelPendingOperations() override; - bool Init(const std::string& host, uint16_t port); + static void Connect(provider::TaskRunner* task_runner, + const std::string& host, + uint16_t port, + const provider::Network::OpenSslSocketCallback& callback); private: - void RunDelayedTask(const base::Closure& task); + struct SslDeleter { + void operator()(BIO* bio) const; + void operator()(SSL* ssl) const; + void operator()(SSL_CTX* ctx) const; + }; + + SSLStream(provider::TaskRunner* task_runner, + std::unique_ptr<BIO, SslDeleter> stream_bio); + + static void ConnectBio( + std::unique_ptr<SSLStream> stream, + const provider::Network::OpenSslSocketCallback& callback); + static void DoHandshake( + std::unique_ptr<SSLStream> stream, + const provider::Network::OpenSslSocketCallback& callback); + + // Send task to this method with WeakPtr if callback should not be executed + // after SSLStream is destroyed. + void RunTask(const base::Closure& task); provider::TaskRunner* task_runner_{nullptr}; - std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx_{nullptr, SSL_CTX_free}; - std::unique_ptr<SSL, decltype(&SSL_free)> ssl_{nullptr, SSL_free}; + std::unique_ptr<SSL_CTX, SslDeleter> ctx_; + std::unique_ptr<SSL, SslDeleter> ssl_; base::WeakPtrFactory<SSLStream> weak_ptr_factory_{this}; };