Make SSLStream initialization non-blocking

Existing initialization solution was blocking.

BUG:24961367

Change-Id: Iec5125013161cfd0c50ede78efdbdb1cb42ebd64
Reviewed-on: https://weave-review.googlesource.com/1362
Reviewed-by: Alex Vakulenko <avakulenko@google.com>
diff --git a/libweave/examples/provider/event_network.cc b/libweave/examples/provider/event_network.cc
index 82165be..d3a3391 100644
--- a/libweave/examples/provider/event_network.cc
+++ b/libweave/examples/provider/event_network.cc
@@ -100,20 +100,7 @@
 void EventNetworkImpl::OpenSslSocket(const std::string& host,
                                      uint16_t port,
                                      const OpenSslSocketCallback& callback) {
-  // Connect to SSL port instead of upgrading to TLS.
-  std::unique_ptr<SSLStream> tls_stream{new SSLStream{task_runner_}};
-
-  if (tls_stream->Init(host, port)) {
-    task_runner_->PostDelayedTask(
-        FROM_HERE, base::Bind(callback, base::Passed(&tls_stream), nullptr),
-        {});
-  } else {
-    ErrorPtr error;
-    Error::AddTo(&error, FROM_HERE, "tls", "tls_init_failed",
-                 "Failed to initialize TLS stream.");
-    task_runner_->PostDelayedTask(
-        FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
-  }
+  SSLStream::Connect(task_runner_, host, port, callback);
 }
 
 }  // namespace examples
diff --git a/libweave/examples/provider/ssl_stream.cc b/libweave/examples/provider/ssl_stream.cc
index d3cbcfd..eea28c0 100644
--- a/libweave/examples/provider/ssl_stream.cc
+++ b/libweave/examples/provider/ssl_stream.cc
@@ -4,29 +4,66 @@
 
 #include "examples/provider/ssl_stream.h"
 
+#include <openssl/err.h>
+
 #include <base/bind.h>
+#include <base/bind_helpers.h>
 #include <weave/provider/task_runner.h>
 
 namespace weave {
 namespace examples {
 
 namespace {
-int GetSSLError(const SSL* ssl, int ret) {
+
+void AddSslError(ErrorPtr* error,
+                 const tracked_objects::Location& location,
+                 const std::string& error_code,
+                 unsigned long ssl_error_code) {
+  ERR_load_BIO_strings();
   SSL_load_error_strings();
-  return SSL_get_error(ssl, ret);
+  Error::AddToPrintf(error, location, "ssl_stream", error_code, "%s: %s",
+                     ERR_lib_error_string(ssl_error_code),
+                     ERR_reason_error_string(ssl_error_code));
 }
+
+void RetryAsyncTask(provider::TaskRunner* task_runner,
+                    const tracked_objects::Location& location,
+                    const base::Closure& task) {
+  task_runner->PostDelayedTask(FROM_HERE, task,
+                               base::TimeDelta::FromMilliseconds(100));
+}
+
 }  // namespace
 
-SSLStream::SSLStream(provider::TaskRunner* task_runner)
+void SSLStream::SslDeleter::operator()(BIO* bio) const {
+  BIO_free(bio);
+}
+
+void SSLStream::SslDeleter::operator()(SSL* ssl) const {
+  SSL_free(ssl);
+}
+
+void SSLStream::SslDeleter::operator()(SSL_CTX* ctx) const {
+  SSL_CTX_free(ctx);
+}
+
+SSLStream::SSLStream(provider::TaskRunner* task_runner,
+                     std::unique_ptr<BIO, SslDeleter> stream_bio)
     : task_runner_{task_runner} {
-  SSL_library_init();
+  ctx_.reset(SSL_CTX_new(TLSv1_2_client_method()));
+  CHECK(ctx_);
+  ssl_.reset(SSL_new(ctx_.get()));
+
+  SSL_set_bio(ssl_.get(), stream_bio.get(), stream_bio.get());
+  stream_bio.release();  // Owned by ssl now.
+  SSL_set_connect_state(ssl_.get());
 }
 
 SSLStream::~SSLStream() {
   CancelPendingOperations();
 }
 
-void SSLStream::RunDelayedTask(const base::Closure& task) {
+void SSLStream::RunTask(const base::Closure& task) {
   task.Run();
 }
 
@@ -37,31 +74,28 @@
   if (res > 0) {
     task_runner_->PostDelayedTask(
         FROM_HERE,
-        base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(),
+        base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
                    base::Bind(callback, res, nullptr)),
         {});
     return;
   }
 
-  int err = GetSSLError(ssl_.get(), res);
+  int err = SSL_get_error(ssl_.get(), res);
 
   if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
-    task_runner_->PostDelayedTask(
-        FROM_HERE, base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(),
-                              buffer, size_to_read, callback),
-        base::TimeDelta::FromSeconds(1));
-    return;
+    return RetryAsyncTask(
+        task_runner_, FROM_HERE,
+        base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(), buffer,
+                   size_to_read, callback));
   }
 
   ErrorPtr weave_error;
-  Error::AddTo(&weave_error, FROM_HERE, "ssl", "socket_read_failed",
-               "SSL error");
-  task_runner_->PostDelayedTask(
+  AddSslError(&weave_error, FROM_HERE, "read_failed", err);
+  return task_runner_->PostDelayedTask(
       FROM_HERE,
-      base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(),
+      base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
                  base::Bind(callback, 0, base::Passed(&weave_error))),
       {});
-  return;
 }
 
 void SSLStream::Write(const void* buffer,
@@ -72,81 +106,101 @@
     buffer = static_cast<const char*>(buffer) + res;
     size_to_write -= res;
     if (size_to_write == 0) {
-      task_runner_->PostDelayedTask(
+      return task_runner_->PostDelayedTask(
           FROM_HERE,
-          base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(),
+          base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
                      base::Bind(callback, nullptr)),
           {});
-      return;
     }
 
-    task_runner_->PostDelayedTask(
-        FROM_HERE, base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(),
-                              buffer, size_to_write, callback),
-        base::TimeDelta::FromSeconds(1));
-
-    return;
+    return RetryAsyncTask(
+        task_runner_, FROM_HERE,
+        base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer,
+                   size_to_write, callback));
   }
 
-  int err = GetSSLError(ssl_.get(), res);
+  int err = SSL_get_error(ssl_.get(), res);
 
   if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
-    task_runner_->PostDelayedTask(
-        FROM_HERE, base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(),
-                              buffer, size_to_write, callback),
-        base::TimeDelta::FromSeconds(1));
-    return;
+    return RetryAsyncTask(
+        task_runner_, FROM_HERE,
+        base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer,
+                   size_to_write, callback));
   }
 
   ErrorPtr weave_error;
-  Error::AddTo(&weave_error, FROM_HERE, "ssl", "socket_write_failed",
-               "SSL error");
+  AddSslError(&weave_error, FROM_HERE, "write_failed", err);
   task_runner_->PostDelayedTask(
-      FROM_HERE,
-      base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(),
-                 base::Bind(callback, base::Passed(&weave_error))),
+      FROM_HERE, base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
+                            base::Bind(callback, base::Passed(&weave_error))),
       {});
-  return;
 }
 
 void SSLStream::CancelPendingOperations() {
   weak_ptr_factory_.InvalidateWeakPtrs();
 }
 
-bool SSLStream::Init(const std::string& host, uint16_t port) {
-  ctx_.reset(SSL_CTX_new(TLSv1_2_client_method()));
-  CHECK(ctx_);
-  ssl_.reset(SSL_new(ctx_.get()));
+void SSLStream::Connect(
+    provider::TaskRunner* task_runner,
+    const std::string& host,
+    uint16_t port,
+    const provider::Network::OpenSslSocketCallback& callback) {
+  SSL_library_init();
 
   char end_point[255];
   snprintf(end_point, sizeof(end_point), "%s:%u", host.c_str(), port);
-  BIO* stream_bio = BIO_new_connect(end_point);
+
+  std::unique_ptr<BIO, SslDeleter> stream_bio(BIO_new_connect(end_point));
   CHECK(stream_bio);
-  BIO_set_nbio(stream_bio, 1);
+  BIO_set_nbio(stream_bio.get(), 1);
 
-  while (BIO_do_connect(stream_bio) != 1) {
-    CHECK(BIO_should_retry(stream_bio));
-    sleep(1);
+  std::unique_ptr<SSLStream> stream{
+      new SSLStream{task_runner, std::move(stream_bio)}};
+  ConnectBio(std::move(stream), callback);
+}
+
+void SSLStream::ConnectBio(
+    std::unique_ptr<SSLStream> stream,
+    const provider::Network::OpenSslSocketCallback& callback) {
+  BIO* bio = SSL_get_rbio(stream->ssl_.get());
+  if (BIO_do_connect(bio) == 1)
+    return DoHandshake(std::move(stream), callback);
+
+  auto task_runner = stream->task_runner_;
+  if (BIO_should_retry(bio)) {
+    return RetryAsyncTask(
+        task_runner, FROM_HERE,
+        base::Bind(&SSLStream::ConnectBio, base::Passed(&stream), callback));
   }
 
-  SSL_set_bio(ssl_.get(), stream_bio, stream_bio);
-  SSL_set_connect_state(ssl_.get());
+  ErrorPtr error;
+  AddSslError(&error, FROM_HERE, "connect_failed", ERR_get_error());
+  task_runner->PostDelayedTask(
+      FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
+}
 
-  for (;;) {
-    int res = SSL_do_handshake(ssl_.get());
-    if (res) {
-      return true;
-    }
-
-    res = GetSSLError(ssl_.get(), res);
-
-    if (res != SSL_ERROR_WANT_READ || res != SSL_ERROR_WANT_WRITE) {
-      return false;
-    }
-
-    sleep(1);
+void SSLStream::DoHandshake(
+    std::unique_ptr<SSLStream> stream,
+    const provider::Network::OpenSslSocketCallback& callback) {
+  int res = SSL_do_handshake(stream->ssl_.get());
+  auto task_runner = stream->task_runner_;
+  if (res == 1) {
+    return task_runner->PostDelayedTask(
+        FROM_HERE, base::Bind(callback, base::Passed(&stream), nullptr), {});
   }
-  return false;
+
+  res = SSL_get_error(stream->ssl_.get(), res);
+
+  if (res == SSL_ERROR_WANT_READ || res == SSL_ERROR_WANT_WRITE) {
+    return RetryAsyncTask(
+        task_runner, FROM_HERE,
+        base::Bind(&SSLStream::DoHandshake, base::Passed(&stream), callback));
+  }
+
+  ErrorPtr error;
+  AddSslError(&error, FROM_HERE, "handshake_failed", res);
+  task_runner->PostDelayedTask(
+      FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
 }
 
 }  // namespace examples
diff --git a/libweave/examples/provider/ssl_stream.h b/libweave/examples/provider/ssl_stream.h
index ca3731b..f1ef4a6 100644
--- a/libweave/examples/provider/ssl_stream.h
+++ b/libweave/examples/provider/ssl_stream.h
@@ -8,6 +8,7 @@
 #include <openssl/ssl.h>
 
 #include <base/memory/weak_ptr.h>
+#include <weave/provider/network.h>
 #include <weave/stream.h>
 
 namespace weave {
@@ -20,8 +21,6 @@
 
 class SSLStream : public Stream {
  public:
-  explicit SSLStream(provider::TaskRunner* task_runner);
-
   ~SSLStream() override;
 
   void Read(void* buffer,
@@ -34,14 +33,35 @@
 
   void CancelPendingOperations() override;
 
-  bool Init(const std::string& host, uint16_t port);
+  static void Connect(provider::TaskRunner* task_runner,
+                      const std::string& host,
+                      uint16_t port,
+                      const provider::Network::OpenSslSocketCallback& callback);
 
  private:
-  void RunDelayedTask(const base::Closure& task);
+  struct SslDeleter {
+    void operator()(BIO* bio) const;
+    void operator()(SSL* ssl) const;
+    void operator()(SSL_CTX* ctx) const;
+  };
+
+  SSLStream(provider::TaskRunner* task_runner,
+            std::unique_ptr<BIO, SslDeleter> stream_bio);
+
+  static void ConnectBio(
+      std::unique_ptr<SSLStream> stream,
+      const provider::Network::OpenSslSocketCallback& callback);
+  static void DoHandshake(
+      std::unique_ptr<SSLStream> stream,
+      const provider::Network::OpenSslSocketCallback& callback);
+
+  // Send task to this method with WeakPtr if callback should not be executed
+  // after SSLStream is destroyed.
+  void RunTask(const base::Closure& task);
 
   provider::TaskRunner* task_runner_{nullptr};
-  std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx_{nullptr, SSL_CTX_free};
-  std::unique_ptr<SSL, decltype(&SSL_free)> ssl_{nullptr, SSL_free};
+  std::unique_ptr<SSL_CTX, SslDeleter> ctx_;
+  std::unique_ptr<SSL, SslDeleter> ssl_;
 
   base::WeakPtrFactory<SSLStream> weak_ptr_factory_{this};
 };