| // Copyright 2015 The Weave Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
#include "examples/provider/ssl_stream.h"

#include <algorithm>
#include <limits>
#include <string>

#include <openssl/err.h>

#include <base/bind.h>
#include <base/bind_helpers.h>
#include <weave/provider/task_runner.h>
| |
| namespace weave { |
| namespace examples { |
| |
| namespace { |
| |
| void AddSslError(ErrorPtr* error, |
| const tracked_objects::Location& location, |
| const std::string& error_code, |
| unsigned long ssl_error_code) { |
| ERR_load_BIO_strings(); |
| SSL_load_error_strings(); |
| Error::AddToPrintf(error, location, error_code, "%s: %s", |
| ERR_lib_error_string(ssl_error_code), |
| ERR_reason_error_string(ssl_error_code)); |
| } |
| |
| void RetryAsyncTask(provider::TaskRunner* task_runner, |
| const tracked_objects::Location& location, |
| const base::Closure& task) { |
| task_runner->PostDelayedTask(FROM_HERE, task, |
| base::TimeDelta::FromMilliseconds(100)); |
| } |
| |
| } // namespace |
| |
| void SSLStream::SslDeleter::operator()(BIO* bio) const { |
| BIO_free(bio); |
| } |
| |
| void SSLStream::SslDeleter::operator()(SSL* ssl) const { |
| SSL_free(ssl); |
| } |
| |
| void SSLStream::SslDeleter::operator()(SSL_CTX* ctx) const { |
| SSL_CTX_free(ctx); |
| } |
| |
| SSLStream::SSLStream(provider::TaskRunner* task_runner, |
| std::unique_ptr<BIO, SslDeleter> stream_bio) |
| : task_runner_{task_runner} { |
| ctx_.reset(SSL_CTX_new(TLSv1_2_client_method())); |
| CHECK(ctx_); |
| ssl_.reset(SSL_new(ctx_.get())); |
| |
| SSL_set_bio(ssl_.get(), stream_bio.get(), stream_bio.get()); |
| stream_bio.release(); // Owned by ssl now. |
| SSL_set_connect_state(ssl_.get()); |
| } |
| |
| SSLStream::~SSLStream() { |
| CancelPendingOperations(); |
| } |
| |
| void SSLStream::RunTask(const base::Closure& task) { |
| task.Run(); |
| } |
| |
| void SSLStream::Read(void* buffer, |
| size_t size_to_read, |
| const ReadCallback& callback) { |
| int res = SSL_read(ssl_.get(), buffer, size_to_read); |
| if (res > 0) { |
| task_runner_->PostDelayedTask( |
| FROM_HERE, |
| base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), |
| base::Bind(callback, res, nullptr)), |
| {}); |
| return; |
| } |
| |
| int err = SSL_get_error(ssl_.get(), res); |
| |
| if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { |
| return RetryAsyncTask( |
| task_runner_, FROM_HERE, |
| base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(), buffer, |
| size_to_read, callback)); |
| } |
| |
| ErrorPtr weave_error; |
| AddSslError(&weave_error, FROM_HERE, "read_failed", err); |
| return task_runner_->PostDelayedTask( |
| FROM_HERE, |
| base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), |
| base::Bind(callback, 0, base::Passed(&weave_error))), |
| {}); |
| } |
| |
| void SSLStream::Write(const void* buffer, |
| size_t size_to_write, |
| const WriteCallback& callback) { |
| int res = SSL_write(ssl_.get(), buffer, size_to_write); |
| if (res > 0) { |
| buffer = static_cast<const char*>(buffer) + res; |
| size_to_write -= res; |
| if (size_to_write == 0) { |
| return task_runner_->PostDelayedTask( |
| FROM_HERE, |
| base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), |
| base::Bind(callback, nullptr)), |
| {}); |
| } |
| |
| return RetryAsyncTask( |
| task_runner_, FROM_HERE, |
| base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer, |
| size_to_write, callback)); |
| } |
| |
| int err = SSL_get_error(ssl_.get(), res); |
| |
| if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { |
| return RetryAsyncTask( |
| task_runner_, FROM_HERE, |
| base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer, |
| size_to_write, callback)); |
| } |
| |
| ErrorPtr weave_error; |
| AddSslError(&weave_error, FROM_HERE, "write_failed", err); |
| task_runner_->PostDelayedTask( |
| FROM_HERE, base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), |
| base::Bind(callback, base::Passed(&weave_error))), |
| {}); |
| } |
| |
| void SSLStream::CancelPendingOperations() { |
| weak_ptr_factory_.InvalidateWeakPtrs(); |
| } |
| |
| void SSLStream::Connect( |
| provider::TaskRunner* task_runner, |
| const std::string& host, |
| uint16_t port, |
| const provider::Network::OpenSslSocketCallback& callback) { |
| SSL_library_init(); |
| |
| char end_point[255]; |
| snprintf(end_point, sizeof(end_point), "%s:%u", host.c_str(), port); |
| |
| std::unique_ptr<BIO, SslDeleter> stream_bio(BIO_new_connect(end_point)); |
| CHECK(stream_bio); |
| BIO_set_nbio(stream_bio.get(), 1); |
| |
| std::unique_ptr<SSLStream> stream{ |
| new SSLStream{task_runner, std::move(stream_bio)}}; |
| ConnectBio(std::move(stream), callback); |
| } |
| |
| void SSLStream::ConnectBio( |
| std::unique_ptr<SSLStream> stream, |
| const provider::Network::OpenSslSocketCallback& callback) { |
| BIO* bio = SSL_get_rbio(stream->ssl_.get()); |
| if (BIO_do_connect(bio) == 1) |
| return DoHandshake(std::move(stream), callback); |
| |
| auto task_runner = stream->task_runner_; |
| if (BIO_should_retry(bio)) { |
| return RetryAsyncTask( |
| task_runner, FROM_HERE, |
| base::Bind(&SSLStream::ConnectBio, base::Passed(&stream), callback)); |
| } |
| |
| ErrorPtr error; |
| AddSslError(&error, FROM_HERE, "connect_failed", ERR_get_error()); |
| task_runner->PostDelayedTask( |
| FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {}); |
| } |
| |
| void SSLStream::DoHandshake( |
| std::unique_ptr<SSLStream> stream, |
| const provider::Network::OpenSslSocketCallback& callback) { |
| int res = SSL_do_handshake(stream->ssl_.get()); |
| auto task_runner = stream->task_runner_; |
| if (res == 1) { |
| return task_runner->PostDelayedTask( |
| FROM_HERE, base::Bind(callback, base::Passed(&stream), nullptr), {}); |
| } |
| |
| res = SSL_get_error(stream->ssl_.get(), res); |
| |
| if (res == SSL_ERROR_WANT_READ || res == SSL_ERROR_WANT_WRITE) { |
| return RetryAsyncTask( |
| task_runner, FROM_HERE, |
| base::Bind(&SSLStream::DoHandshake, base::Passed(&stream), callback)); |
| } |
| |
| ErrorPtr error; |
| AddSslError(&error, FROM_HERE, "handshake_failed", res); |
| task_runner->PostDelayedTask( |
| FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {}); |
| } |
| |
| } // namespace examples |
| } // namespace weave |