Make SSLStream initialization non-blocking
Existing initialization solution was blocking.
BUG:24961367
Change-Id: Iec5125013161cfd0c50ede78efdbdb1cb42ebd64
Reviewed-on: https://weave-review.googlesource.com/1362
Reviewed-by: Alex Vakulenko <avakulenko@google.com>
diff --git a/libweave/examples/provider/event_network.cc b/libweave/examples/provider/event_network.cc
index 82165be..d3a3391 100644
--- a/libweave/examples/provider/event_network.cc
+++ b/libweave/examples/provider/event_network.cc
@@ -100,20 +100,7 @@
void EventNetworkImpl::OpenSslSocket(const std::string& host,
uint16_t port,
const OpenSslSocketCallback& callback) {
- // Connect to SSL port instead of upgrading to TLS.
- std::unique_ptr<SSLStream> tls_stream{new SSLStream{task_runner_}};
-
- if (tls_stream->Init(host, port)) {
- task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(callback, base::Passed(&tls_stream), nullptr),
- {});
- } else {
- ErrorPtr error;
- Error::AddTo(&error, FROM_HERE, "tls", "tls_init_failed",
- "Failed to initialize TLS stream.");
- task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
- }
+ SSLStream::Connect(task_runner_, host, port, callback);
}
} // namespace examples
diff --git a/libweave/examples/provider/ssl_stream.cc b/libweave/examples/provider/ssl_stream.cc
index d3cbcfd..eea28c0 100644
--- a/libweave/examples/provider/ssl_stream.cc
+++ b/libweave/examples/provider/ssl_stream.cc
@@ -4,29 +4,66 @@
#include "examples/provider/ssl_stream.h"
+#include <openssl/err.h>
+
#include <base/bind.h>
+#include <base/bind_helpers.h>
#include <weave/provider/task_runner.h>
namespace weave {
namespace examples {
namespace {
-int GetSSLError(const SSL* ssl, int ret) {
+
+void AddSslError(ErrorPtr* error,
+ const tracked_objects::Location& location,
+ const std::string& error_code,
+ unsigned long ssl_error_code) {
+ ERR_load_BIO_strings();
SSL_load_error_strings();
- return SSL_get_error(ssl, ret);
+ Error::AddToPrintf(error, location, "ssl_stream", error_code, "%s: %s",
+ ERR_lib_error_string(ssl_error_code),
+ ERR_reason_error_string(ssl_error_code));
}
+
+void RetryAsyncTask(provider::TaskRunner* task_runner,
+ const tracked_objects::Location& location,
+ const base::Closure& task) {
+ task_runner->PostDelayedTask(FROM_HERE, task,
+ base::TimeDelta::FromMilliseconds(100));
+}
+
} // namespace
-SSLStream::SSLStream(provider::TaskRunner* task_runner)
+void SSLStream::SslDeleter::operator()(BIO* bio) const {
+ BIO_free(bio);
+}
+
+void SSLStream::SslDeleter::operator()(SSL* ssl) const {
+ SSL_free(ssl);
+}
+
+void SSLStream::SslDeleter::operator()(SSL_CTX* ctx) const {
+ SSL_CTX_free(ctx);
+}
+
+SSLStream::SSLStream(provider::TaskRunner* task_runner,
+ std::unique_ptr<BIO, SslDeleter> stream_bio)
: task_runner_{task_runner} {
- SSL_library_init();
+ ctx_.reset(SSL_CTX_new(TLSv1_2_client_method()));
+ CHECK(ctx_);
+ ssl_.reset(SSL_new(ctx_.get()));
+
+ SSL_set_bio(ssl_.get(), stream_bio.get(), stream_bio.get());
+ stream_bio.release(); // Owned by ssl now.
+ SSL_set_connect_state(ssl_.get());
}
SSLStream::~SSLStream() {
CancelPendingOperations();
}
-void SSLStream::RunDelayedTask(const base::Closure& task) {
+void SSLStream::RunTask(const base::Closure& task) {
task.Run();
}
@@ -37,31 +74,28 @@
if (res > 0) {
task_runner_->PostDelayedTask(
FROM_HERE,
- base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(),
+ base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
base::Bind(callback, res, nullptr)),
{});
return;
}
- int err = GetSSLError(ssl_.get(), res);
+ int err = SSL_get_error(ssl_.get(), res);
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
- task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(),
- buffer, size_to_read, callback),
- base::TimeDelta::FromSeconds(1));
- return;
+ return RetryAsyncTask(
+ task_runner_, FROM_HERE,
+ base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(), buffer,
+ size_to_read, callback));
}
ErrorPtr weave_error;
- Error::AddTo(&weave_error, FROM_HERE, "ssl", "socket_read_failed",
- "SSL error");
- task_runner_->PostDelayedTask(
+ AddSslError(&weave_error, FROM_HERE, "read_failed", err);
+ return task_runner_->PostDelayedTask(
FROM_HERE,
- base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(),
+ base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
base::Bind(callback, 0, base::Passed(&weave_error))),
{});
- return;
}
void SSLStream::Write(const void* buffer,
@@ -72,81 +106,101 @@
buffer = static_cast<const char*>(buffer) + res;
size_to_write -= res;
if (size_to_write == 0) {
- task_runner_->PostDelayedTask(
+ return task_runner_->PostDelayedTask(
FROM_HERE,
- base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(),
+ base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
base::Bind(callback, nullptr)),
{});
- return;
}
- task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(),
- buffer, size_to_write, callback),
- base::TimeDelta::FromSeconds(1));
-
- return;
+ return RetryAsyncTask(
+ task_runner_, FROM_HERE,
+ base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer,
+ size_to_write, callback));
}
- int err = GetSSLError(ssl_.get(), res);
+ int err = SSL_get_error(ssl_.get(), res);
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
- task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(),
- buffer, size_to_write, callback),
- base::TimeDelta::FromSeconds(1));
- return;
+ return RetryAsyncTask(
+ task_runner_, FROM_HERE,
+ base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer,
+ size_to_write, callback));
}
ErrorPtr weave_error;
- Error::AddTo(&weave_error, FROM_HERE, "ssl", "socket_write_failed",
- "SSL error");
+ AddSslError(&weave_error, FROM_HERE, "write_failed", err);
task_runner_->PostDelayedTask(
- FROM_HERE,
- base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(),
- base::Bind(callback, base::Passed(&weave_error))),
+ FROM_HERE, base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
+ base::Bind(callback, base::Passed(&weave_error))),
{});
- return;
}
void SSLStream::CancelPendingOperations() {
weak_ptr_factory_.InvalidateWeakPtrs();
}
-bool SSLStream::Init(const std::string& host, uint16_t port) {
- ctx_.reset(SSL_CTX_new(TLSv1_2_client_method()));
- CHECK(ctx_);
- ssl_.reset(SSL_new(ctx_.get()));
+void SSLStream::Connect(
+ provider::TaskRunner* task_runner,
+ const std::string& host,
+ uint16_t port,
+ const provider::Network::OpenSslSocketCallback& callback) {
+ SSL_library_init();
char end_point[255];
snprintf(end_point, sizeof(end_point), "%s:%u", host.c_str(), port);
- BIO* stream_bio = BIO_new_connect(end_point);
+
+ std::unique_ptr<BIO, SslDeleter> stream_bio(BIO_new_connect(end_point));
CHECK(stream_bio);
- BIO_set_nbio(stream_bio, 1);
+ BIO_set_nbio(stream_bio.get(), 1);
- while (BIO_do_connect(stream_bio) != 1) {
- CHECK(BIO_should_retry(stream_bio));
- sleep(1);
+ std::unique_ptr<SSLStream> stream{
+ new SSLStream{task_runner, std::move(stream_bio)}};
+ ConnectBio(std::move(stream), callback);
+}
+
+void SSLStream::ConnectBio(
+ std::unique_ptr<SSLStream> stream,
+ const provider::Network::OpenSslSocketCallback& callback) {
+ BIO* bio = SSL_get_rbio(stream->ssl_.get());
+ if (BIO_do_connect(bio) == 1)
+ return DoHandshake(std::move(stream), callback);
+
+ auto task_runner = stream->task_runner_;
+ if (BIO_should_retry(bio)) {
+ return RetryAsyncTask(
+ task_runner, FROM_HERE,
+ base::Bind(&SSLStream::ConnectBio, base::Passed(&stream), callback));
}
- SSL_set_bio(ssl_.get(), stream_bio, stream_bio);
- SSL_set_connect_state(ssl_.get());
+ ErrorPtr error;
+ AddSslError(&error, FROM_HERE, "connect_failed", ERR_get_error());
+ task_runner->PostDelayedTask(
+ FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
+}
- for (;;) {
- int res = SSL_do_handshake(ssl_.get());
- if (res) {
- return true;
- }
-
- res = GetSSLError(ssl_.get(), res);
-
- if (res != SSL_ERROR_WANT_READ || res != SSL_ERROR_WANT_WRITE) {
- return false;
- }
-
- sleep(1);
+void SSLStream::DoHandshake(
+ std::unique_ptr<SSLStream> stream,
+ const provider::Network::OpenSslSocketCallback& callback) {
+ int res = SSL_do_handshake(stream->ssl_.get());
+ auto task_runner = stream->task_runner_;
+ if (res == 1) {
+ return task_runner->PostDelayedTask(
+ FROM_HERE, base::Bind(callback, base::Passed(&stream), nullptr), {});
}
- return false;
+
+ res = SSL_get_error(stream->ssl_.get(), res);
+
+ if (res == SSL_ERROR_WANT_READ || res == SSL_ERROR_WANT_WRITE) {
+ return RetryAsyncTask(
+ task_runner, FROM_HERE,
+ base::Bind(&SSLStream::DoHandshake, base::Passed(&stream), callback));
+ }
+
+ ErrorPtr error;
+ AddSslError(&error, FROM_HERE, "handshake_failed", res);
+ task_runner->PostDelayedTask(
+ FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
}
} // namespace examples
diff --git a/libweave/examples/provider/ssl_stream.h b/libweave/examples/provider/ssl_stream.h
index ca3731b..f1ef4a6 100644
--- a/libweave/examples/provider/ssl_stream.h
+++ b/libweave/examples/provider/ssl_stream.h
@@ -8,6 +8,7 @@
#include <openssl/ssl.h>
#include <base/memory/weak_ptr.h>
+#include <weave/provider/network.h>
#include <weave/stream.h>
namespace weave {
@@ -20,8 +21,6 @@
class SSLStream : public Stream {
public:
- explicit SSLStream(provider::TaskRunner* task_runner);
-
~SSLStream() override;
void Read(void* buffer,
@@ -34,14 +33,35 @@
void CancelPendingOperations() override;
- bool Init(const std::string& host, uint16_t port);
+ static void Connect(provider::TaskRunner* task_runner,
+ const std::string& host,
+ uint16_t port,
+ const provider::Network::OpenSslSocketCallback& callback);
private:
- void RunDelayedTask(const base::Closure& task);
+ struct SslDeleter {
+ void operator()(BIO* bio) const;
+ void operator()(SSL* ssl) const;
+ void operator()(SSL_CTX* ctx) const;
+ };
+
+ SSLStream(provider::TaskRunner* task_runner,
+ std::unique_ptr<BIO, SslDeleter> stream_bio);
+
+ static void ConnectBio(
+ std::unique_ptr<SSLStream> stream,
+ const provider::Network::OpenSslSocketCallback& callback);
+ static void DoHandshake(
+ std::unique_ptr<SSLStream> stream,
+ const provider::Network::OpenSslSocketCallback& callback);
+
+ // Send task to this method with WeakPtr if callback should not be executed
+ // after SSLStream is destroyed.
+ void RunTask(const base::Closure& task);
provider::TaskRunner* task_runner_{nullptr};
- std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx_{nullptr, SSL_CTX_free};
- std::unique_ptr<SSL, decltype(&SSL_free)> ssl_{nullptr, SSL_free};
+ std::unique_ptr<SSL_CTX, SslDeleter> ctx_;
+ std::unique_ptr<SSL, SslDeleter> ssl_;
base::WeakPtrFactory<SSLStream> weak_ptr_factory_{this};
};