Switch from starttls to SSL sockets
SSL sockets simplifies both libweave and device side.
Also removed Flush and Close method, as not needed with current design.
BUG=b:23529805
Change-Id: I6e2a5522f8a55d20b0ed6363a9955d37505d2c06
diff --git a/libweave/examples/ubuntu/network_manager.cc b/libweave/examples/ubuntu/network_manager.cc
index 48c817c..ab1d521 100644
--- a/libweave/examples/ubuntu/network_manager.cc
+++ b/libweave/examples/ubuntu/network_manager.cc
@@ -39,160 +39,11 @@
NOTREACHED();
}
-class SocketStream : public Stream {
- public:
- explicit SocketStream(TaskRunner* task_runner) : task_runner_{task_runner} {}
-
- ~SocketStream() { CloseBlocking(nullptr); }
-
- void RunDelayedTask(const base::Closure& success_callback) {
- success_callback.Run();
- }
-
- bool ReadAsync(void* buffer,
- size_t size_to_read,
- const base::Callback<void(size_t)>& success_callback,
- const base::Callback<void(const Error*)>& error_callback,
- ErrorPtr* error) {
- if (socket_fd_ < 0) {
- Error::AddTo(error, FROM_HERE, "socket", "invalid_socket",
- strerror(errno));
- return false;
- }
- int size_read = recv(socket_fd_, buffer, size_to_read, MSG_DONTWAIT);
- if (size_read > 0) {
- task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(&SocketStream::RunDelayedTask,
- weak_ptr_factory_.GetWeakPtr(),
- base::Bind(success_callback, size_read)),
- {});
- return true;
- }
- if (errno == EAGAIN || errno == EWOULDBLOCK) {
- task_runner_->PostDelayedTask(
- FROM_HERE,
- base::Bind(base::IgnoreResult(&SocketStream::ReadAsync),
- weak_ptr_factory_.GetWeakPtr(), buffer, size_to_read,
- success_callback, error_callback, nullptr),
- base::TimeDelta::FromMilliseconds(200));
- return true;
- }
-
- ErrorPtr recv_error;
- Error::AddTo(&recv_error, FROM_HERE, "socket", "socket_recv_failed",
- strerror(errno));
- task_runner_->PostDelayedTask(
- FROM_HERE,
- base::Bind(error_callback, base::Owned(recv_error.release())), {});
- return true;
- }
-
- bool WriteAllAsync(const void* buffer,
- size_t size_to_write,
- const base::Closure& success_callback,
- const base::Callback<void(const Error*)>& error_callback,
- ErrorPtr* error) {
- if (socket_fd_ < 0) {
- Error::AddTo(error, FROM_HERE, "socket", "invalid_socket",
- strerror(errno));
- return false;
- }
- const char* buffer_ptr = static_cast<const char*>(buffer);
- do {
- int size_sent = send(socket_fd_, buffer_ptr, size_to_write, 0);
- if (size_sent <= 0) {
- ErrorPtr send_error;
- Error::AddTo(&send_error, FROM_HERE, "socket", "socket_send_failed",
- strerror(errno));
- task_runner_->PostDelayedTask(
- FROM_HERE,
- base::Bind(error_callback, base::Owned(send_error.release())), {});
- // Still true as we return error with callback.
- return true;
- }
- size_to_write -= size_sent;
- buffer_ptr += size_sent;
- } while (size_to_write > 0);
-
- task_runner_->PostDelayedTask(FROM_HERE, success_callback, {});
- return true;
- }
- bool FlushBlocking(ErrorPtr* error) { return true; }
-
- bool CloseBlocking(ErrorPtr* error) {
- weak_ptr_factory_.InvalidateWeakPtrs();
- if (socket_fd_ >= 0) {
- close(socket_fd_);
- socket_fd_ = -1;
- }
- }
-
- void CancelPendingAsyncOperations() {
- weak_ptr_factory_.InvalidateWeakPtrs();
- }
-
- bool Connect(const std::string& host, uint16_t port) {
- std::string service = std::to_string(port);
- addrinfo hints = {0, AF_UNSPEC, SOCK_STREAM};
- addrinfo* result = nullptr;
- if (getaddrinfo(host.c_str(), service.c_str(), &hints, &result)) {
- LOG(ERROR) << "Failed to resolve host name: " << host;
- return false;
- }
- std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> result_deleter{
- result, &freeaddrinfo};
-
- for (const addrinfo* info = result; info != nullptr; info = info->ai_next) {
- socket_fd_ =
- socket(info->ai_family, info->ai_socktype, info->ai_protocol);
- if (socket_fd_ < 0)
- continue;
-
- int flags = fcntl(socket_fd_, F_GETFL, 0);
- if (flags == -1)
- flags = 0;
- fcntl(socket_fd_, F_SETFL, flags | O_NONBLOCK);
-
- LOG(INFO) << "Connecting...";
- if (connect(socket_fd_, info->ai_addr, info->ai_addrlen) == 0)
- break; // Success.
-
- if (errno == EINPROGRESS) {
- fd_set write_fds;
- FD_ZERO(&write_fds);
- FD_SET(socket_fd_, &write_fds);
-
- struct timeval tv;
- tv.tv_sec = 5;
- tv.tv_usec = 0;
-
- int select_ret = select(socket_fd_ + 1, NULL, &write_fds, NULL, &tv);
- if (select_ret != -1 && select_ret != 0) {
- break;
- }
- }
-
- LOG(ERROR) << "Failed to connect";
- CloseBlocking(nullptr);
- }
-
- return socket_fd_ >= 0;
- }
-
- int GetFd() const { return socket_fd_; }
-
- private:
- TaskRunner* task_runner_{nullptr};
- int socket_fd_{-1};
-
- base::WeakPtrFactory<SocketStream> weak_ptr_factory_{this};
-};
-
class SSLStream : public Stream {
public:
explicit SSLStream(TaskRunner* task_runner) : task_runner_{task_runner} {}
- ~SSLStream() { weak_ptr_factory_.InvalidateWeakPtrs(); }
+ ~SSLStream() { CancelPendingAsyncOperations(); }
void RunDelayedTask(const base::Closure& success_callback) {
success_callback.Run();
@@ -289,33 +140,27 @@
return true;
}
- bool FlushBlocking(ErrorPtr* error) { return true; }
-
- bool CloseBlocking(ErrorPtr* error) {
- weak_ptr_factory_.InvalidateWeakPtrs();
- return true;
- }
-
void CancelPendingAsyncOperations() {
weak_ptr_factory_.InvalidateWeakPtrs();
}
- bool Init() {
+ bool Init(const std::string& host, uint16_t port) {
ctx_.reset(SSL_CTX_new(TLSv1_2_client_method()));
CHECK(ctx_);
ssl_.reset(SSL_new(ctx_.get()));
- char endpoint[] = "talk.google.com:5223";
- stream_bio_ = BIO_new_connect(endpoint);
- CHECK(stream_bio_);
- BIO_set_nbio(stream_bio_, 1);
+ char end_point[255];
+ snprintf(end_point, sizeof(end_point), "%s:%u", host.c_str(), port);
+ BIO* stream_bio = BIO_new_connect(end_point);
+ CHECK(stream_bio);
+ BIO_set_nbio(stream_bio, 1);
- while (BIO_do_connect(stream_bio_) != 1) {
- CHECK(BIO_should_retry(stream_bio_));
+ while (BIO_do_connect(stream_bio) != 1) {
+ CHECK(BIO_should_retry(stream_bio));
sleep(1);
}
- SSL_set_bio(ssl_.get(), stream_bio_, stream_bio_);
+ SSL_set_bio(ssl_.get(), stream_bio, stream_bio);
SSL_set_connect_state(ssl_.get());
for (;;) {
@@ -339,7 +184,6 @@
TaskRunner* task_runner_{nullptr};
std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx_{nullptr, SSL_CTX_free};
std::unique_ptr<SSL, decltype(&SSL_free)> ssl_{nullptr, SSL_free};
- BIO* stream_bio_{nullptr};
base::WeakPtrFactory<SSLStream> weak_ptr_factory_{this};
};
@@ -504,23 +348,15 @@
i.Run(online);
}
-std::unique_ptr<Stream> NetworkImpl::OpenSocketBlocking(const std::string& host,
- uint16_t port) {
- std::unique_ptr<SocketStream> stream{new SocketStream{task_runner_}};
- if (!stream->Connect(host, port))
- return nullptr;
- return std::move(stream);
-}
-
-void NetworkImpl::CreateTlsStream(
- std::unique_ptr<Stream> stream,
+void NetworkImpl::OpenSslSocket(
const std::string& host,
+ uint16_t port,
const base::Callback<void(std::unique_ptr<Stream>)>& success_callback,
const base::Callback<void(const Error*)>& error_callback) {
// Connect to SSL port instead of upgrading to TLS.
std::unique_ptr<SSLStream> tls_stream{new SSLStream{task_runner_}};
- if (tls_stream->Init()) {
+ if (tls_stream->Init(host, port)) {
task_runner_->PostDelayedTask(
FROM_HERE, base::Bind(success_callback, base::Passed(&tls_stream)), {});
} else {
diff --git a/libweave/examples/ubuntu/network_manager.h b/libweave/examples/ubuntu/network_manager.h
index c58d220..c9d6ff4 100644
--- a/libweave/examples/ubuntu/network_manager.h
+++ b/libweave/examples/ubuntu/network_manager.h
@@ -34,11 +34,9 @@
NetworkState GetConnectionState() const override;
void EnableAccessPoint(const std::string& ssid) override;
void DisableAccessPoint() override;
- std::unique_ptr<Stream> OpenSocketBlocking(const std::string& host,
- uint16_t port) override;
- void CreateTlsStream(
- std::unique_ptr<Stream> stream,
+ void OpenSslSocket(
const std::string& host,
+ uint16_t port,
const base::Callback<void(std::unique_ptr<Stream>)>& success_callback,
const base::Callback<void(const Error*)>& error_callback) override;
diff --git a/libweave/include/weave/network.h b/libweave/include/weave/network.h
index 41e5413..0f0b12b 100644
--- a/libweave/include/weave/network.h
+++ b/libweave/include/weave/network.h
@@ -48,14 +48,9 @@
virtual void DisableAccessPoint() = 0;
// Opens bidirectional sockets and returns attached stream.
- // TODO(vitalybuka): Make async.
- virtual std::unique_ptr<Stream> OpenSocketBlocking(const std::string& host,
- uint16_t port) = 0;
-
- // Replaces stream with version with TLS support.
- virtual void CreateTlsStream(
- std::unique_ptr<Stream> socket,
+ virtual void OpenSslSocket(
const std::string& host,
+ uint16_t port,
const base::Callback<void(std::unique_ptr<Stream>)>& success_callback,
const base::Callback<void(const Error*)>& error_callback) = 0;
diff --git a/libweave/include/weave/stream.h b/libweave/include/weave/stream.h
index 3ecdba6..9d4e657 100644
--- a/libweave/include/weave/stream.h
+++ b/libweave/include/weave/stream.h
@@ -30,10 +30,6 @@
const base::Callback<void(const Error*)>& error_callback,
ErrorPtr* error) = 0;
- virtual bool FlushBlocking(ErrorPtr* error) = 0;
-
- virtual bool CloseBlocking(ErrorPtr* error) = 0;
-
virtual void CancelPendingAsyncOperations() = 0;
};
diff --git a/libweave/include/weave/test/mock_network.h b/libweave/include/weave/test/mock_network.h
index f09533e..f188223 100644
--- a/libweave/include/weave/test/mock_network.h
+++ b/libweave/include/weave/test/mock_network.h
@@ -30,22 +30,11 @@
MOCK_METHOD1(EnableAccessPoint, void(const std::string&));
MOCK_METHOD0(DisableAccessPoint, void());
- MOCK_METHOD2(MockOpenSocketBlocking, Stream*(const std::string&, uint16_t));
- MOCK_METHOD2(MockCreateTlsStream, Stream*(Stream*, const std::string&));
-
- std::unique_ptr<Stream> OpenSocketBlocking(const std::string& host,
- uint16_t port) override {
- return std::unique_ptr<Stream>{MockOpenSocketBlocking(host, port)};
- }
-
- void CreateTlsStream(
- std::unique_ptr<Stream> socket,
- const std::string& host,
- const base::Callback<void(std::unique_ptr<Stream>)>& success_callback,
- const base::Callback<void(const Error*)>& error_callback) override {
- success_callback.Run(
- std::unique_ptr<Stream>{MockCreateTlsStream(socket.get(), host)});
- }
+ MOCK_METHOD4(OpenSslSocket,
+ void(const std::string&,
+ uint16_t,
+ const base::Callback<void(std::unique_ptr<Stream>)>&,
+ const base::Callback<void(const Error*)>&));
};
} // namespace test
diff --git a/libweave/src/notification/xmpp_channel.cc b/libweave/src/notification/xmpp_channel.cc
index dc9f851..ad4e233 100644
--- a/libweave/src/notification/xmpp_channel.cc
+++ b/libweave/src/notification/xmpp_channel.cc
@@ -75,7 +75,7 @@
};
const char kDefaultXmppHost[] = "talk.google.com";
-const uint16_t kDefaultXmppPort = 5222;
+const uint16_t kDefaultXmppPort = 5223;
// Used for keeping connection alive.
const int kRegularPingIntervalSeconds = 60;
@@ -152,20 +152,6 @@
switch (state_) {
case XmppState::kConnected:
- if (stanza->name() == "stream:features" &&
- stanza->FindFirstChild("starttls/required", false)) {
- state_ = XmppState::kTlsStarted;
- SendMessage("<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>");
- return;
- }
- break;
- case XmppState::kTlsStarted:
- if (stanza->name() == "proceed") {
- StartTlsHandshake();
- return;
- }
- break;
- case XmppState::kTlsCompleted:
if (stanza->name() == "stream:features") {
auto children = stanza->FindChildren("mechanisms/mechanism", false);
for (const auto& child : children) {
@@ -217,7 +203,7 @@
}
// Something bad happened. Close the stream and start over.
LOG(ERROR) << "Error condition occurred handling stanza: "
- << stanza->ToString();
+ << stanza->ToString() << " in state: " << static_cast<int>(state_);
CloseStream();
}
@@ -290,25 +276,37 @@
ParseNotificationJson(*json_dict, delegate_);
}
-void XmppChannel::StartTlsHandshake() {
- raw_socket_->CancelPendingAsyncOperations();
- network_->CreateTlsStream(
- std::move(raw_socket_), host_,
- base::Bind(&XmppChannel::OnTlsHandshakeComplete,
+void XmppChannel::CreateSslSocket() {
+ CHECK(!stream_);
+ state_ = XmppState::kConnecting;
+ LOG(INFO) << "Starting XMPP connection to " << kDefaultXmppHost << ":" << kDefaultXmppPort;
+
+ network_->OpenSslSocket(
+ kDefaultXmppHost, kDefaultXmppPort,
+ base::Bind(&XmppChannel::OnSslSocketReady,
task_ptr_factory_.GetWeakPtr()),
- base::Bind(&XmppChannel::OnTlsError, task_ptr_factory_.GetWeakPtr()));
+ base::Bind(&XmppChannel::OnSslError, task_ptr_factory_.GetWeakPtr()));
}
-void XmppChannel::OnTlsHandshakeComplete(std::unique_ptr<Stream> tls_stream) {
- tls_stream_ = std::move(tls_stream);
- stream_ = tls_stream_.get();
- state_ = XmppState::kTlsCompleted;
+void XmppChannel::OnSslSocketReady(std::unique_ptr<Stream> stream) {
+ CHECK(XmppState::kConnecting == state_);
+ backoff_entry_.InformOfRequest(true);
+ stream_ = std::move(stream);
+ state_ = XmppState::kConnected;
RestartXmppStream();
+ ScheduleRegularPing();
}
-void XmppChannel::OnTlsError(const Error* error) {
+void XmppChannel::OnSslError(const Error* error) {
LOG(ERROR) << "TLS handshake failed. Restarting XMPP connection";
- Restart();
+ backoff_entry_.InformOfRequest(false);
+
+ LOG(INFO) << "Delaying connection to XMPP server for "
+ << backoff_entry_.GetTimeUntilRelease();
+ task_runner_->PostDelayedTask(
+ FROM_HERE,
+ base::Bind(&XmppChannel::CreateSslSocket, task_ptr_factory_.GetWeakPtr()),
+ backoff_entry_.GetTimeUntilRelease());
}
void XmppChannel::SendMessage(const std::string& message) {
@@ -336,10 +334,6 @@
void XmppChannel::OnMessageSent() {
ErrorPtr error;
write_pending_ = false;
- if (!stream_->FlushBlocking(&error)) {
- OnWriteError(error.get());
- return;
- }
if (queued_write_data_.empty()) {
WaitForMessage();
} else {
@@ -373,30 +367,6 @@
Restart();
}
-void XmppChannel::Connect(const std::string& host,
- uint16_t port,
- const base::Closure& callback) {
- state_ = XmppState::kConnecting;
- LOG(INFO) << "Starting XMPP connection to " << host << ":" << port;
- raw_socket_ = network_->OpenSocketBlocking(host, port);
-
- backoff_entry_.InformOfRequest(raw_socket_ != nullptr);
- if (raw_socket_) {
- host_ = host;
- port_ = port;
- stream_ = raw_socket_.get();
- callback.Run();
- } else {
- LOG(INFO) << "Delaying connection to XMPP server " << host << " for "
- << backoff_entry_.GetTimeUntilRelease();
- task_runner_->PostDelayedTask(
- FROM_HERE,
- base::Bind(&XmppChannel::Connect, task_ptr_factory_.GetWeakPtr(), host,
- port, callback),
- backoff_entry_.GetTimeUntilRelease());
- }
-}
-
std::string XmppChannel::GetName() const {
return "xmpp";
}
@@ -419,9 +389,7 @@
CHECK(state_ == XmppState::kNotStarted);
delegate_ = delegate;
- Connect(
- kDefaultXmppHost, kDefaultXmppPort,
- base::Bind(&XmppChannel::OnConnected, task_ptr_factory_.GetWeakPtr()));
+ CreateSslSocket();
}
void XmppChannel::Stop() {
@@ -431,25 +399,10 @@
task_ptr_factory_.InvalidateWeakPtrs();
ping_ptr_factory_.InvalidateWeakPtrs();
- if (tls_stream_) {
- tls_stream_->CloseBlocking(nullptr);
- tls_stream_.reset();
- }
- if (raw_socket_) {
- raw_socket_->CloseBlocking(nullptr);
- raw_socket_.reset();
- }
- stream_ = nullptr;
+ stream_.reset();
state_ = XmppState::kNotStarted;
}
-void XmppChannel::OnConnected() {
- CHECK(XmppState::kConnecting == state_);
- state_ = XmppState::kConnected;
- RestartXmppStream();
- ScheduleRegularPing();
-}
-
void XmppChannel::RestartXmppStream() {
stream_parser_.Reset();
stream_->CancelPendingAsyncOperations();
diff --git a/libweave/src/notification/xmpp_channel.h b/libweave/src/notification/xmpp_channel.h
index 112799f..ddbdee1 100644
--- a/libweave/src/notification/xmpp_channel.h
+++ b/libweave/src/notification/xmpp_channel.h
@@ -61,8 +61,6 @@
kNotStarted,
kConnecting,
kConnected,
- kTlsStarted,
- kTlsCompleted,
kAuthenticationStarted,
kAuthenticationFailed,
kStreamRestartedPostAuthentication,
@@ -75,20 +73,13 @@
protected:
// These methods are internal helpers that can be overloaded by unit tests
// to help provide unit-test-specific functionality.
- virtual void Connect(const std::string& host,
- uint16_t port,
- const base::Closure& callback);
virtual void SchedulePing(base::TimeDelta interval, base::TimeDelta timeout);
void ScheduleRegularPing();
void ScheduleFastPing();
- XmppState state_{XmppState::kNotStarted};
-
- // The connection socket stream to the XMPP server.
- Stream* stream_{nullptr};
-
private:
friend class IqStanzaHandler;
+ friend class FakeXmppChannel;
// Overrides from XmppStreamParser::Delegate.
void OnStreamStart(const std::string& node_name,
@@ -103,13 +94,12 @@
void HandleMessageStanza(std::unique_ptr<XmlNode> stanza);
void RestartXmppStream();
- void StartTlsHandshake();
- void OnTlsHandshakeComplete(std::unique_ptr<Stream> tls_stream);
- void OnTlsError(const Error* error);
+ void CreateSslSocket();
+ void OnSslSocketReady(std::unique_ptr<Stream> stream);
+ void OnSslError(const Error* error);
void WaitForMessage();
- void OnConnected();
void OnMessageRead(size_t size);
void OnMessageSent();
void OnReadError(const Error* error);
@@ -130,6 +120,8 @@
void OnConnectivityChanged(bool online);
+ XmppState state_{XmppState::kNotStarted};
+
// Robot account name for the device.
std::string account_;
@@ -140,8 +132,7 @@
std::string access_token_;
Network* network_{nullptr};
- std::unique_ptr<Stream> raw_socket_;
- std::unique_ptr<Stream> tls_stream_; // Must follow |raw_socket_|.
+ std::unique_ptr<Stream> stream_;
// Read buffer for incoming message packets.
std::vector<char> read_socket_data_;
diff --git a/libweave/src/notification/xmpp_channel_unittest.cc b/libweave/src/notification/xmpp_channel_unittest.cc
index 574dfa2..d7c8174 100644
--- a/libweave/src/notification/xmpp_channel_unittest.cc
+++ b/libweave/src/notification/xmpp_channel_unittest.cc
@@ -8,11 +8,16 @@
#include <queue>
#include <gtest/gtest.h>
+#include <weave/test/mock_network.h>
#include <weave/test/mock_task_runner.h>
#include "libweave/src/bind_lambda.h"
+using testing::_;
+using testing::Invoke;
+using testing::Return;
using testing::StrictMock;
+using testing::WithArgs;
namespace weave {
@@ -21,16 +26,22 @@
constexpr char kAccountName[] = "Account@Name";
constexpr char kAccessToken[] = "AccessToken";
+constexpr char kStartStreamMessage[] =
+ "<stream:stream to='clouddevices.gserviceaccount.com' "
+ "xmlns:stream='http://etherx.jabber.org/streams' xml:lang='*' "
+ "version='1.0' xmlns='jabber:client'>";
constexpr char kStartStreamResponse[] =
"<stream:stream from=\"clouddevices.gserviceaccount.com\" "
"id=\"0CCF520913ABA04B\" version=\"1.0\" "
"xmlns:stream=\"http://etherx.jabber.org/streams\" "
- "xmlns=\"jabber:client\">"
- "<stream:features><starttls xmlns=\"urn:ietf:params:xml:ns:xmpp-tls\">"
- "<required/></starttls><mechanisms "
- "xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\"><mechanism>X-OAUTH2</mechanism>"
- "<mechanism>X-GOOGLE-TOKEN</mechanism></mechanisms></stream:features>";
-constexpr char kTlsStreamResponse[] =
+ "xmlns=\"jabber:client\">";
+constexpr char kAuthenticationMessage[] =
+ "<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='X-OAUTH2' "
+ "auth:service='oauth2' auth:allow-non-google-login='true' "
+ "auth:client-uses-full-bind-result='true' "
+ "xmlns:auth='http://www.google.com/talk/protocol/auth'>"
+ "AEFjY291bnRATmFtZQBBY2Nlc3NUb2tlbg==</auth>";
+constexpr char kConnectedResponse[] =
"<stream:features><mechanisms xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\">"
"<mechanism>X-OAUTH2</mechanism>"
"<mechanism>X-GOOGLE-TOKEN</mechanism></mechanisms></stream:features>";
@@ -55,18 +66,6 @@
"19853128\" from=\""
"110cc78f78d7032cc7bf2c6e14c1fa7d@clouddevices.gserviceaccount.com\" "
"id=\"3\" type=\"result\"/>";
-constexpr char kStartStreamMessage[] =
- "<stream:stream to='clouddevices.gserviceaccount.com' "
- "xmlns:stream='http://etherx.jabber.org/streams' xml:lang='*' "
- "version='1.0' xmlns='jabber:client'>";
-constexpr char kStartTlsMessage[] =
- "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>";
-constexpr char kAuthenticationMessage[] =
- "<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='X-OAUTH2' "
- "auth:service='oauth2' auth:allow-non-google-login='true' "
- "auth:client-uses-full-bind-result='true' "
- "xmlns:auth='http://www.google.com/talk/protocol/auth'>"
- "AEFjY291bnRATmFtZQBBY2Nlc3NUb2tlbg==</auth>";
constexpr char kBindMessage[] =
"<iq id='1' type='set'><bind "
"xmlns='urn:ietf:params:xml:ns:xmpp-bind'/></iq>";
@@ -84,10 +83,6 @@
public:
explicit FakeStream(TaskRunner* task_runner) : task_runner_{task_runner} {}
- bool FlushBlocking(ErrorPtr* error) override { return true; }
-
- bool CloseBlocking(ErrorPtr* error) override { return true; }
-
void CancelPendingAsyncOperations() override {}
void ExpectWritePacketString(base::TimeDelta, const std::string& data) {
@@ -134,35 +129,55 @@
class FakeXmppChannel : public XmppChannel {
public:
- explicit FakeXmppChannel(TaskRunner* task_runner)
- : XmppChannel{kAccountName, kAccessToken, task_runner, nullptr},
- fake_stream_{task_runner} {}
+ explicit FakeXmppChannel(TaskRunner* task_runner, weave::Network* network)
+ : XmppChannel{kAccountName, kAccessToken, task_runner, network},
+ stream_{new FakeStream{task_runner_}},
+ fake_stream_{stream_.get()} {}
+
+ void Connect(
+ const base::Callback<void(std::unique_ptr<weave::Stream>)>& callback) {
+ callback.Run(std::move(stream_));
+ }
XmppState state() const { return state_; }
void set_state(XmppState state) { state_ = state; }
- void Connect(const std::string& host,
- uint16_t port,
- const base::Closure& callback) override {
- set_state(XmppState::kConnecting);
- stream_ = &fake_stream_;
- callback.Run();
- }
-
void SchedulePing(base::TimeDelta interval,
base::TimeDelta timeout) override {}
- FakeStream fake_stream_;
+ void ExpectWritePacketString(base::TimeDelta delta, const std::string& data) {
+ fake_stream_->ExpectWritePacketString(delta, data);
+ }
+
+ void AddReadPacketString(base::TimeDelta delta, const std::string& data) {
+ fake_stream_->AddReadPacketString(delta, data);
+ }
+
+ std::unique_ptr<FakeStream> stream_;
+ FakeStream* fake_stream_{nullptr};
+};
+
+class MockNetwork : public weave::test::MockNetwork {
+ public:
+ MockNetwork() {
+ EXPECT_CALL(*this, AddOnConnectionChangedCallback(_))
+ .WillRepeatedly(Return());
+ }
};
class XmppChannelTest : public ::testing::Test {
protected:
+ XmppChannelTest() {
+ EXPECT_CALL(network_, OpenSslSocket("talk.google.com", 5223, _, _))
+ .WillOnce(
+ WithArgs<2>(Invoke(&xmpp_client_, &FakeXmppChannel::Connect)));
+ }
+
void StartStream() {
- xmpp_client_.fake_stream_.ExpectWritePacketString({}, kStartStreamMessage);
- xmpp_client_.fake_stream_.AddReadPacketString({}, kStartStreamResponse);
- xmpp_client_.fake_stream_.ExpectWritePacketString({}, kStartTlsMessage);
+ xmpp_client_.ExpectWritePacketString({}, kStartStreamMessage);
+ xmpp_client_.AddReadPacketString({}, kStartStreamResponse);
xmpp_client_.Start(nullptr);
- RunUntil(XmppChannel::XmppState::kTlsStarted);
+ RunUntil(XmppChannel::XmppState::kConnected);
}
void StartWithState(XmppChannel::XmppState state) {
@@ -177,12 +192,13 @@
}
StrictMock<test::MockTaskRunner> task_runner_;
- FakeXmppChannel xmpp_client_{&task_runner_};
+ StrictMock<MockNetwork> network_;
+ FakeXmppChannel xmpp_client_{&task_runner_, &network_};
};
TEST_F(XmppChannelTest, StartStream) {
EXPECT_EQ(XmppChannel::XmppState::kNotStarted, xmpp_client_.state());
- xmpp_client_.fake_stream_.ExpectWritePacketString({}, kStartStreamMessage);
+ xmpp_client_.ExpectWritePacketString({}, kStartStreamMessage);
xmpp_client_.Start(nullptr);
RunUntil(XmppChannel::XmppState::kConnected);
}
@@ -191,48 +207,46 @@
StartStream();
}
-TEST_F(XmppChannelTest, HandleTLSCompleted) {
- StartWithState(XmppChannel::XmppState::kTlsCompleted);
- xmpp_client_.fake_stream_.AddReadPacketString({}, kTlsStreamResponse);
- xmpp_client_.fake_stream_.ExpectWritePacketString({}, kAuthenticationMessage);
+TEST_F(XmppChannelTest, HandleConnected) {
+ StartWithState(XmppChannel::XmppState::kConnected);
+ xmpp_client_.AddReadPacketString({}, kConnectedResponse);
+ xmpp_client_.ExpectWritePacketString({}, kAuthenticationMessage);
RunUntil(XmppChannel::XmppState::kAuthenticationStarted);
}
TEST_F(XmppChannelTest, HandleAuthenticationSucceededResponse) {
StartWithState(XmppChannel::XmppState::kAuthenticationStarted);
- xmpp_client_.fake_stream_.AddReadPacketString(
- {}, kAuthenticationSucceededResponse);
- xmpp_client_.fake_stream_.ExpectWritePacketString({}, kStartStreamMessage);
+ xmpp_client_.AddReadPacketString({}, kAuthenticationSucceededResponse);
+ xmpp_client_.ExpectWritePacketString({}, kStartStreamMessage);
RunUntil(XmppChannel::XmppState::kStreamRestartedPostAuthentication);
}
TEST_F(XmppChannelTest, HandleAuthenticationFailedResponse) {
StartWithState(XmppChannel::XmppState::kAuthenticationStarted);
- xmpp_client_.fake_stream_.AddReadPacketString({},
- kAuthenticationFailedResponse);
+ xmpp_client_.AddReadPacketString({}, kAuthenticationFailedResponse);
RunUntil(XmppChannel::XmppState::kAuthenticationFailed);
}
TEST_F(XmppChannelTest, HandleStreamRestartedResponse) {
StartWithState(XmppChannel::XmppState::kStreamRestartedPostAuthentication);
- xmpp_client_.fake_stream_.AddReadPacketString({}, kRestartStreamResponse);
- xmpp_client_.fake_stream_.ExpectWritePacketString({}, kBindMessage);
+ xmpp_client_.AddReadPacketString({}, kRestartStreamResponse);
+ xmpp_client_.ExpectWritePacketString({}, kBindMessage);
RunUntil(XmppChannel::XmppState::kBindSent);
EXPECT_TRUE(xmpp_client_.jid().empty());
- xmpp_client_.fake_stream_.AddReadPacketString({}, kBindResponse);
- xmpp_client_.fake_stream_.ExpectWritePacketString({}, kSessionMessage);
+ xmpp_client_.AddReadPacketString({}, kBindResponse);
+ xmpp_client_.ExpectWritePacketString({}, kSessionMessage);
RunUntil(XmppChannel::XmppState::kSessionStarted);
EXPECT_EQ(
"110cc78f78d7032cc7bf2c6e14c1fa7d@clouddevices.gserviceaccount.com"
"/19853128",
xmpp_client_.jid());
- xmpp_client_.fake_stream_.AddReadPacketString({}, kSessionResponse);
- xmpp_client_.fake_stream_.ExpectWritePacketString({}, kSubscribeMessage);
+ xmpp_client_.AddReadPacketString({}, kSessionResponse);
+ xmpp_client_.ExpectWritePacketString({}, kSubscribeMessage);
RunUntil(XmppChannel::XmppState::kSubscribeStarted);
- xmpp_client_.fake_stream_.AddReadPacketString({}, kSubscribedResponse);
+ xmpp_client_.AddReadPacketString({}, kSubscribedResponse);
RunUntil(XmppChannel::XmppState::kSubscribed);
}