// Copyright 2015 The Weave Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "examples/provider/ssl_stream.h"

#include <string>

#include <openssl/err.h>

#include <base/bind.h>
#include <base/bind_helpers.h>
#include <weave/provider/task_runner.h>

namespace weave {
namespace examples {

namespace {

// Appends an "ssl_stream" domain error to |error| describing the OpenSSL
// failure identified by |ssl_error_code|.
//
// |location| is the source location recorded with the error, |error_code| is
// the weave error code string, and |ssl_error_code| is expected to be a packed
// error-queue value (as returned by ERR_get_error()).
void AddSslError(ErrorPtr* error,
                 const tracked_objects::Location& location,
                 const std::string& error_code,
                 unsigned long ssl_error_code) {
  ERR_load_BIO_strings();
  SSL_load_error_strings();

  // Both lookups return NULL when no text is registered for the code (for
  // example when the error queue was empty and the code is 0). Passing NULL
  // to a "%s" conversion is undefined behavior, so substitute a placeholder.
  const char* lib_text = ERR_lib_error_string(ssl_error_code);
  const char* reason_text = ERR_reason_error_string(ssl_error_code);
  Error::AddToPrintf(error, location, "ssl_stream", error_code, "%s: %s",
                     lib_text ? lib_text : "unknown library",
                     reason_text ? reason_text : "unknown reason");
}

// Posts |task| on |task_runner| to run after a fixed 100 ms delay. Used to
// poll again when a non-blocking OpenSSL operation could not make progress.
//
// |location| identifies the caller scheduling the retry and is forwarded to
// the task runner, so the posted task is attributed to the call site instead
// of to this helper.
void RetryAsyncTask(provider::TaskRunner* task_runner,
                    const tracked_objects::Location& location,
                    const base::Closure& task) {
  task_runner->PostDelayedTask(location, task,
                               base::TimeDelta::FromMilliseconds(100));
}

}  // namespace

// Deleter overload for BIO handles: releases |bio| with BIO_free().
void SSLStream::SslDeleter::operator()(BIO* bio) const {
  BIO_free(bio);
}

// Deleter overload for SSL connection handles: releases |ssl| with SSL_free().
void SSLStream::SslDeleter::operator()(SSL* ssl) const {
  SSL_free(ssl);
}

// Deleter overload for SSL context handles: releases |ctx| with
// SSL_CTX_free().
void SSLStream::SslDeleter::operator()(SSL_CTX* ctx) const {
  SSL_CTX_free(ctx);
}

// Creates a TLS 1.2 client context and an SSL object on top of |stream_bio|,
// and puts the SSL object into client (connect) state. |task_runner| is stored
// and later used to post completion callbacks and retries.
//
// Ownership of |stream_bio| is transferred to the SSL object, which uses it
// for both reading and writing.
//
// NOTE(review): nothing in this constructor configures a verify mode or trust
// store on the context; confirm whether skipping peer-certificate verification
// is intended for this example.
SSLStream::SSLStream(provider::TaskRunner* task_runner,
                     std::unique_ptr<BIO, SslDeleter> stream_bio)
    : task_runner_{task_runner} {
  ctx_.reset(SSL_CTX_new(TLSv1_2_client_method()));
  CHECK(ctx_);
  ssl_.reset(SSL_new(ctx_.get()));
  // SSL_new() can fail just like SSL_CTX_new(); fail fast here instead of
  // handing a null SSL* to the calls below.
  CHECK(ssl_);

  // Hand the BIO over to the SSL object, which owns it from now on.
  BIO* bio = stream_bio.release();
  SSL_set_bio(ssl_.get(), bio, bio);
  SSL_set_connect_state(ssl_.get());
}

// Invalidates all outstanding weak pointers (via CancelPendingOperations())
// before the members are destroyed.
SSLStream::~SSLStream() {
  CancelPendingOperations();
}

// Runs |task| immediately. Read() and Write() bind this method to a weak
// pointer when posting completion callbacks, presumably so that a posted
// callback is skipped once the weak pointers have been invalidated.
void SSLStream::RunTask(const base::Closure& task) {
  task.Run();
}

void SSLStream::Read(void* buffer,
                     size_t size_to_read,
                     const ReadCallback& callback) {
  int res = SSL_read(ssl_.get(), buffer, size_to_read);
  if (res > 0) {
    task_runner_->PostDelayedTask(
        FROM_HERE,
        base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
                   base::Bind(callback, res, nullptr)),
        {});
    return;
  }

  int err = SSL_get_error(ssl_.get(), res);

  if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
    return RetryAsyncTask(
        task_runner_, FROM_HERE,
        base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(), buffer,
                   size_to_read, callback));
  }

  ErrorPtr weave_error;
  AddSslError(&weave_error, FROM_HERE, "read_failed", err);
  return task_runner_->PostDelayedTask(
      FROM_HERE,
      base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
                 base::Bind(callback, 0, base::Passed(&weave_error))),
      {});
}

void SSLStream::Write(const void* buffer,
                      size_t size_to_write,
                      const WriteCallback& callback) {
  int res = SSL_write(ssl_.get(), buffer, size_to_write);
  if (res > 0) {
    buffer = static_cast<const char*>(buffer) + res;
    size_to_write -= res;
    if (size_to_write == 0) {
      return task_runner_->PostDelayedTask(
          FROM_HERE,
          base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
                     base::Bind(callback, nullptr)),
          {});
    }

    return RetryAsyncTask(
        task_runner_, FROM_HERE,
        base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer,
                   size_to_write, callback));
  }

  int err = SSL_get_error(ssl_.get(), res);

  if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
    return RetryAsyncTask(
        task_runner_, FROM_HERE,
        base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer,
                   size_to_write, callback));
  }

  ErrorPtr weave_error;
  AddSslError(&weave_error, FROM_HERE, "write_failed", err);
  task_runner_->PostDelayedTask(
      FROM_HERE, base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
                            base::Bind(callback, base::Passed(&weave_error))),
      {});
}

// Invalidates every weak pointer handed out by |weak_ptr_factory_|. The
// retries and completion callbacks posted by Read() and Write() are bound to
// such weak pointers.
void SSLStream::CancelPendingOperations() {
  weak_ptr_factory_.InvalidateWeakPtrs();
}

// Starts an asynchronous TLS connection to |host|:|port|.
//
// Creates a non-blocking connect BIO for the endpoint, wraps it in a new
// SSLStream that uses |task_runner|, and hands the stream to ConnectBio(),
// which reports the outcome through |callback|.
void SSLStream::Connect(
    provider::TaskRunner* task_runner,
    const std::string& host,
    uint16_t port,
    const provider::Network::OpenSslSocketCallback& callback) {
  SSL_library_init();

  // Build "host:port" in a std::string so that a long host name is never
  // silently truncated by a fixed-size buffer.
  std::string end_point = host + ":" + std::to_string(port);

  // &end_point[0] yields a mutable, NUL-terminated char*, which is accepted
  // both by OpenSSL versions that declare the parameter as char* and by those
  // that declare it as const char*.
  std::unique_ptr<BIO, SslDeleter> stream_bio(
      BIO_new_connect(&end_point[0]));
  CHECK(stream_bio);
  BIO_set_nbio(stream_bio.get(), 1);

  std::unique_ptr<SSLStream> stream{
      new SSLStream{task_runner, std::move(stream_bio)}};
  ConnectBio(std::move(stream), callback);
}

// Drives the non-blocking connect of the BIO underlying |stream|.
//
// Once the BIO reports that it is connected, control moves on to
// DoHandshake(). While the BIO asks for a retry, this function re-schedules
// itself through RetryAsyncTask(), keeping ownership of |stream| inside the
// bound task. Any other outcome is reported by posting |callback| with a null
// stream and a "connect_failed" error.
void SSLStream::ConnectBio(
    std::unique_ptr<SSLStream> stream,
    const provider::Network::OpenSslSocketCallback& callback) {
  BIO* const connect_bio = SSL_get_rbio(stream->ssl_.get());
  const bool connected = (BIO_do_connect(connect_bio) == 1);
  if (connected) {
    DoHandshake(std::move(stream), callback);
    return;
  }

  // Grab the runner before |stream| is possibly moved into a bound task.
  auto runner = stream->task_runner_;
  if (!BIO_should_retry(connect_bio)) {
    ErrorPtr connect_error;
    AddSslError(&connect_error, FROM_HERE, "connect_failed", ERR_get_error());
    runner->PostDelayedTask(
        FROM_HERE,
        base::Bind(callback, nullptr, base::Passed(&connect_error)), {});
    return;
  }

  RetryAsyncTask(
      runner, FROM_HERE,
      base::Bind(&SSLStream::ConnectBio, base::Passed(&stream), callback));
}

// Drives the TLS handshake for |stream|.
//
// When the handshake completes, |callback| is posted with ownership of
// |stream| and a null error. While OpenSSL reports WANT_READ/WANT_WRITE, this
// function re-schedules itself through RetryAsyncTask(), keeping ownership of
// |stream| inside the bound task. Any other outcome is reported by posting
// |callback| with a null stream and a "handshake_failed" error.
void SSLStream::DoHandshake(
    std::unique_ptr<SSLStream> stream,
    const provider::Network::OpenSslSocketCallback& callback) {
  int res = SSL_do_handshake(stream->ssl_.get());
  // Grab the runner before |stream| is possibly moved into a bound task.
  auto task_runner = stream->task_runner_;
  if (res == 1) {
    return task_runner->PostDelayedTask(
        FROM_HERE, base::Bind(callback, base::Passed(&stream), nullptr), {});
  }

  const int ssl_status = SSL_get_error(stream->ssl_.get(), res);

  if (ssl_status == SSL_ERROR_WANT_READ ||
      ssl_status == SSL_ERROR_WANT_WRITE) {
    return RetryAsyncTask(
        task_runner, FROM_HERE,
        base::Bind(&SSLStream::DoHandshake, base::Passed(&stream), callback));
  }

  // |ssl_status| is an SSL_ERROR_* category, not an error-queue code, so it
  // must not be decoded with the ERR_*_error_string() functions. Report the
  // queued OpenSSL error instead, the same way ConnectBio() does.
  ErrorPtr error;
  AddSslError(&error, FROM_HERE, "handshake_failed", ERR_get_error());
  task_runner->PostDelayedTask(
      FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
}

}  // namespace examples
}  // namespace weave
