Switch from starttls to SSL sockets SSL sockets simplifies both libweave and device side. Also removed Flush and Close method, as not needed with current design. BUG=b:23529805 Change-Id: I6e2a5522f8a55d20b0ed6363a9955d37505d2c06
diff --git a/libweave/examples/ubuntu/network_manager.cc b/libweave/examples/ubuntu/network_manager.cc index 48c817c..ab1d521 100644 --- a/libweave/examples/ubuntu/network_manager.cc +++ b/libweave/examples/ubuntu/network_manager.cc
@@ -39,160 +39,11 @@ NOTREACHED(); } -class SocketStream : public Stream { - public: - explicit SocketStream(TaskRunner* task_runner) : task_runner_{task_runner} {} - - ~SocketStream() { CloseBlocking(nullptr); } - - void RunDelayedTask(const base::Closure& success_callback) { - success_callback.Run(); - } - - bool ReadAsync(void* buffer, - size_t size_to_read, - const base::Callback<void(size_t)>& success_callback, - const base::Callback<void(const Error*)>& error_callback, - ErrorPtr* error) { - if (socket_fd_ < 0) { - Error::AddTo(error, FROM_HERE, "socket", "invalid_socket", - strerror(errno)); - return false; - } - int size_read = recv(socket_fd_, buffer, size_to_read, MSG_DONTWAIT); - if (size_read > 0) { - task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(&SocketStream::RunDelayedTask, - weak_ptr_factory_.GetWeakPtr(), - base::Bind(success_callback, size_read)), - {}); - return true; - } - if (errno == EAGAIN || errno == EWOULDBLOCK) { - task_runner_->PostDelayedTask( - FROM_HERE, - base::Bind(base::IgnoreResult(&SocketStream::ReadAsync), - weak_ptr_factory_.GetWeakPtr(), buffer, size_to_read, - success_callback, error_callback, nullptr), - base::TimeDelta::FromMilliseconds(200)); - return true; - } - - ErrorPtr recv_error; - Error::AddTo(&recv_error, FROM_HERE, "socket", "socket_recv_failed", - strerror(errno)); - task_runner_->PostDelayedTask( - FROM_HERE, - base::Bind(error_callback, base::Owned(recv_error.release())), {}); - return true; - } - - bool WriteAllAsync(const void* buffer, - size_t size_to_write, - const base::Closure& success_callback, - const base::Callback<void(const Error*)>& error_callback, - ErrorPtr* error) { - if (socket_fd_ < 0) { - Error::AddTo(error, FROM_HERE, "socket", "invalid_socket", - strerror(errno)); - return false; - } - const char* buffer_ptr = static_cast<const char*>(buffer); - do { - int size_sent = send(socket_fd_, buffer_ptr, size_to_write, 0); - if (size_sent <= 0) { - ErrorPtr send_error; - Error::AddTo(&send_error, FROM_HERE, "socket", "socket_send_failed", - strerror(errno)); - task_runner_->PostDelayedTask( - FROM_HERE, - base::Bind(error_callback, base::Owned(send_error.release())), {}); - // Still true as we return error with callback. - return true; - } - size_to_write -= size_sent; - buffer_ptr += size_sent; - } while (size_to_write > 0); - - task_runner_->PostDelayedTask(FROM_HERE, success_callback, {}); - return true; - } - bool FlushBlocking(ErrorPtr* error) { return true; } - - bool CloseBlocking(ErrorPtr* error) { - weak_ptr_factory_.InvalidateWeakPtrs(); - if (socket_fd_ >= 0) { - close(socket_fd_); - socket_fd_ = -1; - } - } - - void CancelPendingAsyncOperations() { - weak_ptr_factory_.InvalidateWeakPtrs(); - } - - bool Connect(const std::string& host, uint16_t port) { - std::string service = std::to_string(port); - addrinfo hints = {0, AF_UNSPEC, SOCK_STREAM}; - addrinfo* result = nullptr; - if (getaddrinfo(host.c_str(), service.c_str(), &hints, &result)) { - LOG(ERROR) << "Failed to resolve host name: " << host; - return false; - } - std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> result_deleter{ - result, &freeaddrinfo}; - - for (const addrinfo* info = result; info != nullptr; info = info->ai_next) { - socket_fd_ = - socket(info->ai_family, info->ai_socktype, info->ai_protocol); - if (socket_fd_ < 0) - continue; - - int flags = fcntl(socket_fd_, F_GETFL, 0); - if (flags == -1) - flags = 0; - fcntl(socket_fd_, F_SETFL, flags | O_NONBLOCK); - - LOG(INFO) << "Connecting..."; - if (connect(socket_fd_, info->ai_addr, info->ai_addrlen) == 0) - break; // Success. - - if (errno == EINPROGRESS) { - fd_set write_fds; - FD_ZERO(&write_fds); - FD_SET(socket_fd_, &write_fds); - - struct timeval tv; - tv.tv_sec = 5; - tv.tv_usec = 0; - - int select_ret = select(socket_fd_ + 1, NULL, &write_fds, NULL, &tv); - if (select_ret != -1 && select_ret != 0) { - break; - } - } - - LOG(ERROR) << "Failed to connect"; - CloseBlocking(nullptr); - } - - return socket_fd_ >= 0; - } - - int GetFd() const { return socket_fd_; } - - private: - TaskRunner* task_runner_{nullptr}; - int socket_fd_{-1}; - - base::WeakPtrFactory<SocketStream> weak_ptr_factory_{this}; -}; - class SSLStream : public Stream { public: explicit SSLStream(TaskRunner* task_runner) : task_runner_{task_runner} {} - ~SSLStream() { weak_ptr_factory_.InvalidateWeakPtrs(); } + ~SSLStream() { CancelPendingAsyncOperations(); } void RunDelayedTask(const base::Closure& success_callback) { success_callback.Run(); @@ -289,33 +140,27 @@ return true; } - bool FlushBlocking(ErrorPtr* error) { return true; } - - bool CloseBlocking(ErrorPtr* error) { - weak_ptr_factory_.InvalidateWeakPtrs(); - return true; - } - void CancelPendingAsyncOperations() { weak_ptr_factory_.InvalidateWeakPtrs(); } - bool Init() { + bool Init(const std::string& host, uint16_t port) { ctx_.reset(SSL_CTX_new(TLSv1_2_client_method())); CHECK(ctx_); ssl_.reset(SSL_new(ctx_.get())); - char endpoint[] = "talk.google.com:5223"; - stream_bio_ = BIO_new_connect(endpoint); - CHECK(stream_bio_); - BIO_set_nbio(stream_bio_, 1); + char end_point[255]; + snprintf(end_point, sizeof(end_point), "%s:%u", host.c_str(), port); + BIO* stream_bio = BIO_new_connect(end_point); + CHECK(stream_bio); + BIO_set_nbio(stream_bio, 1); - while (BIO_do_connect(stream_bio_) != 1) { - CHECK(BIO_should_retry(stream_bio_)); + while (BIO_do_connect(stream_bio) != 1) { + CHECK(BIO_should_retry(stream_bio)); sleep(1); } - SSL_set_bio(ssl_.get(), stream_bio_, stream_bio_); + SSL_set_bio(ssl_.get(), stream_bio, stream_bio); SSL_set_connect_state(ssl_.get()); for (;;) { @@ -339,7 +184,6 @@ TaskRunner* task_runner_{nullptr}; std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx_{nullptr, SSL_CTX_free}; std::unique_ptr<SSL, decltype(&SSL_free)> ssl_{nullptr, SSL_free}; - BIO* stream_bio_{nullptr}; base::WeakPtrFactory<SSLStream> weak_ptr_factory_{this}; }; @@ -504,23 +348,15 @@ i.Run(online); } -std::unique_ptr<Stream> NetworkImpl::OpenSocketBlocking(const std::string& host, - uint16_t port) { - std::unique_ptr<SocketStream> stream{new SocketStream{task_runner_}}; - if (!stream->Connect(host, port)) - return nullptr; - return std::move(stream); -} - -void NetworkImpl::CreateTlsStream( - std::unique_ptr<Stream> stream, +void NetworkImpl::OpenSslSocket( const std::string& host, + uint16_t port, const base::Callback<void(std::unique_ptr<Stream>)>& success_callback, const base::Callback<void(const Error*)>& error_callback) { // Connect to SSL port instead of upgrading to TLS. std::unique_ptr<SSLStream> tls_stream{new SSLStream{task_runner_}}; - if (tls_stream->Init()) { + if (tls_stream->Init(host, port)) { task_runner_->PostDelayedTask( FROM_HERE, base::Bind(success_callback, base::Passed(&tls_stream)), {}); } else {
diff --git a/libweave/examples/ubuntu/network_manager.h b/libweave/examples/ubuntu/network_manager.h index c58d220..c9d6ff4 100644 --- a/libweave/examples/ubuntu/network_manager.h +++ b/libweave/examples/ubuntu/network_manager.h
@@ -34,11 +34,9 @@ NetworkState GetConnectionState() const override; void EnableAccessPoint(const std::string& ssid) override; void DisableAccessPoint() override; - std::unique_ptr<Stream> OpenSocketBlocking(const std::string& host, - uint16_t port) override; - void CreateTlsStream( - std::unique_ptr<Stream> stream, + void OpenSslSocket( const std::string& host, + uint16_t port, const base::Callback<void(std::unique_ptr<Stream>)>& success_callback, const base::Callback<void(const Error*)>& error_callback) override;
diff --git a/libweave/include/weave/network.h b/libweave/include/weave/network.h index 41e5413..0f0b12b 100644 --- a/libweave/include/weave/network.h +++ b/libweave/include/weave/network.h
@@ -48,14 +48,9 @@ virtual void DisableAccessPoint() = 0; // Opens bidirectional sockets and returns attached stream. - // TODO(vitalybuka): Make async. - virtual std::unique_ptr<Stream> OpenSocketBlocking(const std::string& host, - uint16_t port) = 0; - - // Replaces stream with version with TLS support. - virtual void CreateTlsStream( - std::unique_ptr<Stream> socket, + virtual void OpenSslSocket( const std::string& host, + uint16_t port, const base::Callback<void(std::unique_ptr<Stream>)>& success_callback, const base::Callback<void(const Error*)>& error_callback) = 0;
diff --git a/libweave/include/weave/stream.h b/libweave/include/weave/stream.h index 3ecdba6..9d4e657 100644 --- a/libweave/include/weave/stream.h +++ b/libweave/include/weave/stream.h
@@ -30,10 +30,6 @@ const base::Callback<void(const Error*)>& error_callback, ErrorPtr* error) = 0; - virtual bool FlushBlocking(ErrorPtr* error) = 0; - - virtual bool CloseBlocking(ErrorPtr* error) = 0; - virtual void CancelPendingAsyncOperations() = 0; };
diff --git a/libweave/include/weave/test/mock_network.h b/libweave/include/weave/test/mock_network.h index f09533e..f188223 100644 --- a/libweave/include/weave/test/mock_network.h +++ b/libweave/include/weave/test/mock_network.h
@@ -30,22 +30,11 @@ MOCK_METHOD1(EnableAccessPoint, void(const std::string&)); MOCK_METHOD0(DisableAccessPoint, void()); - MOCK_METHOD2(MockOpenSocketBlocking, Stream*(const std::string&, uint16_t)); - MOCK_METHOD2(MockCreateTlsStream, Stream*(Stream*, const std::string&)); - - std::unique_ptr<Stream> OpenSocketBlocking(const std::string& host, - uint16_t port) override { - return std::unique_ptr<Stream>{MockOpenSocketBlocking(host, port)}; - } - - void CreateTlsStream( - std::unique_ptr<Stream> socket, - const std::string& host, - const base::Callback<void(std::unique_ptr<Stream>)>& success_callback, - const base::Callback<void(const Error*)>& error_callback) override { - success_callback.Run( - std::unique_ptr<Stream>{MockCreateTlsStream(socket.get(), host)}); - } + MOCK_METHOD4(OpenSslSocket, + void(const std::string&, + uint16_t, + const base::Callback<void(std::unique_ptr<Stream>)>&, + const base::Callback<void(const Error*)>&)); }; } // namespace test
diff --git a/libweave/src/notification/xmpp_channel.cc b/libweave/src/notification/xmpp_channel.cc index dc9f851..ad4e233 100644 --- a/libweave/src/notification/xmpp_channel.cc +++ b/libweave/src/notification/xmpp_channel.cc
@@ -75,7 +75,7 @@ }; const char kDefaultXmppHost[] = "talk.google.com"; -const uint16_t kDefaultXmppPort = 5222; +const uint16_t kDefaultXmppPort = 5223; // Used for keeping connection alive. const int kRegularPingIntervalSeconds = 60; @@ -152,20 +152,6 @@ switch (state_) { case XmppState::kConnected: - if (stanza->name() == "stream:features" && - stanza->FindFirstChild("starttls/required", false)) { - state_ = XmppState::kTlsStarted; - SendMessage("<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>"); - return; - } - break; - case XmppState::kTlsStarted: - if (stanza->name() == "proceed") { - StartTlsHandshake(); - return; - } - break; - case XmppState::kTlsCompleted: if (stanza->name() == "stream:features") { auto children = stanza->FindChildren("mechanisms/mechanism", false); for (const auto& child : children) { @@ -217,7 +203,7 @@ } // Something bad happened. Close the stream and start over. LOG(ERROR) << "Error condition occurred handling stanza: " - << stanza->ToString(); + << stanza->ToString() << " in state: " << static_cast<int>(state_); CloseStream(); } @@ -290,25 +276,37 @@ ParseNotificationJson(*json_dict, delegate_); } -void XmppChannel::StartTlsHandshake() { - raw_socket_->CancelPendingAsyncOperations(); - network_->CreateTlsStream( - std::move(raw_socket_), host_, - base::Bind(&XmppChannel::OnTlsHandshakeComplete, +void XmppChannel::CreateSslSocket() { + CHECK(!stream_); + state_ = XmppState::kConnecting; + LOG(INFO) << "Starting XMPP connection to " << kDefaultXmppHost << ":" << kDefaultXmppPort; + + network_->OpenSslSocket( + kDefaultXmppHost, kDefaultXmppPort, + base::Bind(&XmppChannel::OnSslSocketReady, task_ptr_factory_.GetWeakPtr()), - base::Bind(&XmppChannel::OnTlsError, task_ptr_factory_.GetWeakPtr())); + base::Bind(&XmppChannel::OnSslError, task_ptr_factory_.GetWeakPtr())); } -void XmppChannel::OnTlsHandshakeComplete(std::unique_ptr<Stream> tls_stream) { - tls_stream_ = std::move(tls_stream); - stream_ = tls_stream_.get(); - state_ = XmppState::kTlsCompleted; +void XmppChannel::OnSslSocketReady(std::unique_ptr<Stream> stream) { + CHECK(XmppState::kConnecting == state_); + backoff_entry_.InformOfRequest(true); + stream_ = std::move(stream); + state_ = XmppState::kConnected; RestartXmppStream(); + ScheduleRegularPing(); } -void XmppChannel::OnTlsError(const Error* error) { +void XmppChannel::OnSslError(const Error* error) { LOG(ERROR) << "TLS handshake failed. Restarting XMPP connection"; - Restart(); + backoff_entry_.InformOfRequest(false); + + LOG(INFO) << "Delaying connection to XMPP server for " + << backoff_entry_.GetTimeUntilRelease(); + task_runner_->PostDelayedTask( + FROM_HERE, + base::Bind(&XmppChannel::CreateSslSocket, task_ptr_factory_.GetWeakPtr()), + backoff_entry_.GetTimeUntilRelease()); } void XmppChannel::SendMessage(const std::string& message) { @@ -336,10 +334,6 @@ void XmppChannel::OnMessageSent() { ErrorPtr error; write_pending_ = false; - if (!stream_->FlushBlocking(&error)) { - OnWriteError(error.get()); - return; - } if (queued_write_data_.empty()) { WaitForMessage(); } else { @@ -373,30 +367,6 @@ Restart(); } -void XmppChannel::Connect(const std::string& host, - uint16_t port, - const base::Closure& callback) { - state_ = XmppState::kConnecting; - LOG(INFO) << "Starting XMPP connection to " << host << ":" << port; - raw_socket_ = network_->OpenSocketBlocking(host, port); - - backoff_entry_.InformOfRequest(raw_socket_ != nullptr); - if (raw_socket_) { - host_ = host; - port_ = port; - stream_ = raw_socket_.get(); - callback.Run(); - } else { - LOG(INFO) << "Delaying connection to XMPP server " << host << " for " - << backoff_entry_.GetTimeUntilRelease(); - task_runner_->PostDelayedTask( - FROM_HERE, - base::Bind(&XmppChannel::Connect, task_ptr_factory_.GetWeakPtr(), host, - port, callback), - backoff_entry_.GetTimeUntilRelease()); - } -} - std::string XmppChannel::GetName() const { return "xmpp"; } @@ -419,9 +389,7 @@ CHECK(state_ == XmppState::kNotStarted); delegate_ = delegate; - Connect( - kDefaultXmppHost, kDefaultXmppPort, - base::Bind(&XmppChannel::OnConnected, task_ptr_factory_.GetWeakPtr())); + CreateSslSocket(); } void XmppChannel::Stop() { @@ -431,25 +399,10 @@ task_ptr_factory_.InvalidateWeakPtrs(); ping_ptr_factory_.InvalidateWeakPtrs(); - if (tls_stream_) { - tls_stream_->CloseBlocking(nullptr); - tls_stream_.reset(); - } - if (raw_socket_) { - raw_socket_->CloseBlocking(nullptr); - raw_socket_.reset(); - } - stream_ = nullptr; + stream_.reset(); state_ = XmppState::kNotStarted; } -void XmppChannel::OnConnected() { - CHECK(XmppState::kConnecting == state_); - state_ = XmppState::kConnected; - RestartXmppStream(); - ScheduleRegularPing(); -} - void XmppChannel::RestartXmppStream() { stream_parser_.Reset(); stream_->CancelPendingAsyncOperations();
diff --git a/libweave/src/notification/xmpp_channel.h b/libweave/src/notification/xmpp_channel.h index 112799f..ddbdee1 100644 --- a/libweave/src/notification/xmpp_channel.h +++ b/libweave/src/notification/xmpp_channel.h
@@ -61,8 +61,6 @@ kNotStarted, kConnecting, kConnected, - kTlsStarted, - kTlsCompleted, kAuthenticationStarted, kAuthenticationFailed, kStreamRestartedPostAuthentication, @@ -75,20 +73,13 @@ protected: // These methods are internal helpers that can be overloaded by unit tests // to help provide unit-test-specific functionality. - virtual void Connect(const std::string& host, - uint16_t port, - const base::Closure& callback); virtual void SchedulePing(base::TimeDelta interval, base::TimeDelta timeout); void ScheduleRegularPing(); void ScheduleFastPing(); - XmppState state_{XmppState::kNotStarted}; - - // The connection socket stream to the XMPP server. - Stream* stream_{nullptr}; - private: friend class IqStanzaHandler; + friend class FakeXmppChannel; // Overrides from XmppStreamParser::Delegate. void OnStreamStart(const std::string& node_name, @@ -103,13 +94,12 @@ void HandleMessageStanza(std::unique_ptr<XmlNode> stanza); void RestartXmppStream(); - void StartTlsHandshake(); - void OnTlsHandshakeComplete(std::unique_ptr<Stream> tls_stream); - void OnTlsError(const Error* error); + void CreateSslSocket(); + void OnSslSocketReady(std::unique_ptr<Stream> stream); + void OnSslError(const Error* error); void WaitForMessage(); - void OnConnected(); void OnMessageRead(size_t size); void OnMessageSent(); void OnReadError(const Error* error); @@ -130,6 +120,8 @@ void OnConnectivityChanged(bool online); + XmppState state_{XmppState::kNotStarted}; + // Robot account name for the device. std::string account_; @@ -140,8 +132,7 @@ std::string access_token_; Network* network_{nullptr}; - std::unique_ptr<Stream> raw_socket_; - std::unique_ptr<Stream> tls_stream_; // Must follow |raw_socket_|. + std::unique_ptr<Stream> stream_; // Read buffer for incoming message packets. std::vector<char> read_socket_data_;
diff --git a/libweave/src/notification/xmpp_channel_unittest.cc b/libweave/src/notification/xmpp_channel_unittest.cc index 574dfa2..d7c8174 100644 --- a/libweave/src/notification/xmpp_channel_unittest.cc +++ b/libweave/src/notification/xmpp_channel_unittest.cc
@@ -8,11 +8,16 @@ #include <queue> #include <gtest/gtest.h> +#include <weave/test/mock_network.h> #include <weave/test/mock_task_runner.h> #include "libweave/src/bind_lambda.h" +using testing::_; +using testing::Invoke; +using testing::Return; using testing::StrictMock; +using testing::WithArgs; namespace weave { @@ -21,16 +26,22 @@ constexpr char kAccountName[] = "Account@Name"; constexpr char kAccessToken[] = "AccessToken"; +constexpr char kStartStreamMessage[] = + "<stream:stream to='clouddevices.gserviceaccount.com' " + "xmlns:stream='http://etherx.jabber.org/streams' xml:lang='*' " + "version='1.0' xmlns='jabber:client'>"; constexpr char kStartStreamResponse[] = "<stream:stream from=\"clouddevices.gserviceaccount.com\" " "id=\"0CCF520913ABA04B\" version=\"1.0\" " "xmlns:stream=\"http://etherx.jabber.org/streams\" " - "xmlns=\"jabber:client\">" - "<stream:features><starttls xmlns=\"urn:ietf:params:xml:ns:xmpp-tls\">" - "<required/></starttls><mechanisms " - "xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\"><mechanism>X-OAUTH2</mechanism>" - "<mechanism>X-GOOGLE-TOKEN</mechanism></mechanisms></stream:features>"; -constexpr char kTlsStreamResponse[] = + "xmlns=\"jabber:client\">"; +constexpr char kAuthenticationMessage[] = + "<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='X-OAUTH2' " + "auth:service='oauth2' auth:allow-non-google-login='true' " + "auth:client-uses-full-bind-result='true' " + "xmlns:auth='http://www.google.com/talk/protocol/auth'>" + "AEFjY291bnRATmFtZQBBY2Nlc3NUb2tlbg==</auth>"; +constexpr char kConnectedResponse[] = "<stream:features><mechanisms xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\">" "<mechanism>X-OAUTH2</mechanism>" "<mechanism>X-GOOGLE-TOKEN</mechanism></mechanisms></stream:features>"; @@ -55,18 +66,6 @@ "19853128\" from=\"" "110cc78f78d7032cc7bf2c6e14c1fa7d@clouddevices.gserviceaccount.com\" " "id=\"3\" type=\"result\"/>"; -constexpr char kStartStreamMessage[] = - "<stream:stream to='clouddevices.gserviceaccount.com' " - "xmlns:stream='http://etherx.jabber.org/streams' xml:lang='*' " - "version='1.0' xmlns='jabber:client'>"; -constexpr char kStartTlsMessage[] = - "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>"; -constexpr char kAuthenticationMessage[] = - "<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='X-OAUTH2' " - "auth:service='oauth2' auth:allow-non-google-login='true' " - "auth:client-uses-full-bind-result='true' " - "xmlns:auth='http://www.google.com/talk/protocol/auth'>" - "AEFjY291bnRATmFtZQBBY2Nlc3NUb2tlbg==</auth>"; constexpr char kBindMessage[] = "<iq id='1' type='set'><bind " "xmlns='urn:ietf:params:xml:ns:xmpp-bind'/></iq>"; @@ -84,10 +83,6 @@ public: explicit FakeStream(TaskRunner* task_runner) : task_runner_{task_runner} {} - bool FlushBlocking(ErrorPtr* error) override { return true; } - - bool CloseBlocking(ErrorPtr* error) override { return true; } - void CancelPendingAsyncOperations() override {} void ExpectWritePacketString(base::TimeDelta, const std::string& data) { @@ -134,35 +129,55 @@ class FakeXmppChannel : public XmppChannel { public: - explicit FakeXmppChannel(TaskRunner* task_runner) - : XmppChannel{kAccountName, kAccessToken, task_runner, nullptr}, - fake_stream_{task_runner} {} + explicit FakeXmppChannel(TaskRunner* task_runner, weave::Network* network) + : XmppChannel{kAccountName, kAccessToken, task_runner, network}, + stream_{new FakeStream{task_runner_}}, + fake_stream_{stream_.get()} {} + + void Connect( + const base::Callback<void(std::unique_ptr<weave::Stream>)>& callback) { + callback.Run(std::move(stream_)); + } XmppState state() const { return state_; } void set_state(XmppState state) { state_ = state; } - void Connect(const std::string& host, - uint16_t port, - const base::Closure& callback) override { - set_state(XmppState::kConnecting); - stream_ = &fake_stream_; - callback.Run(); - } - void SchedulePing(base::TimeDelta interval, base::TimeDelta timeout) override {} - FakeStream fake_stream_; + void ExpectWritePacketString(base::TimeDelta delta, const std::string& data) { + fake_stream_->ExpectWritePacketString(delta, data); + } + + void AddReadPacketString(base::TimeDelta delta, const std::string& data) { + fake_stream_->AddReadPacketString(delta, data); + } + + std::unique_ptr<FakeStream> stream_; + FakeStream* fake_stream_{nullptr}; +}; + +class MockNetwork : public weave::test::MockNetwork { + public: + MockNetwork() { + EXPECT_CALL(*this, AddOnConnectionChangedCallback(_)) + .WillRepeatedly(Return()); + } }; class XmppChannelTest : public ::testing::Test { protected: + XmppChannelTest() { + EXPECT_CALL(network_, OpenSslSocket("talk.google.com", 5223, _, _)) + .WillOnce( + WithArgs<2>(Invoke(&xmpp_client_, &FakeXmppChannel::Connect))); + } + void StartStream() { - xmpp_client_.fake_stream_.ExpectWritePacketString({}, kStartStreamMessage); - xmpp_client_.fake_stream_.AddReadPacketString({}, kStartStreamResponse); - xmpp_client_.fake_stream_.ExpectWritePacketString({}, kStartTlsMessage); + xmpp_client_.ExpectWritePacketString({}, kStartStreamMessage); + xmpp_client_.AddReadPacketString({}, kStartStreamResponse); xmpp_client_.Start(nullptr); - RunUntil(XmppChannel::XmppState::kTlsStarted); + RunUntil(XmppChannel::XmppState::kConnected); } void StartWithState(XmppChannel::XmppState state) { @@ -177,12 +192,13 @@ } StrictMock<test::MockTaskRunner> task_runner_; - FakeXmppChannel xmpp_client_{&task_runner_}; + StrictMock<MockNetwork> network_; + FakeXmppChannel xmpp_client_{&task_runner_, &network_}; }; TEST_F(XmppChannelTest, StartStream) { EXPECT_EQ(XmppChannel::XmppState::kNotStarted, xmpp_client_.state()); - xmpp_client_.fake_stream_.ExpectWritePacketString({}, kStartStreamMessage); + xmpp_client_.ExpectWritePacketString({}, kStartStreamMessage); xmpp_client_.Start(nullptr); RunUntil(XmppChannel::XmppState::kConnected); } @@ -191,48 +207,46 @@ StartStream(); } -TEST_F(XmppChannelTest, HandleTLSCompleted) { - StartWithState(XmppChannel::XmppState::kTlsCompleted); - xmpp_client_.fake_stream_.AddReadPacketString({}, kTlsStreamResponse); - xmpp_client_.fake_stream_.ExpectWritePacketString({}, kAuthenticationMessage); +TEST_F(XmppChannelTest, HandleConnected) { + StartWithState(XmppChannel::XmppState::kConnected); + xmpp_client_.AddReadPacketString({}, kConnectedResponse); + xmpp_client_.ExpectWritePacketString({}, kAuthenticationMessage); RunUntil(XmppChannel::XmppState::kAuthenticationStarted); } TEST_F(XmppChannelTest, HandleAuthenticationSucceededResponse) { StartWithState(XmppChannel::XmppState::kAuthenticationStarted); - xmpp_client_.fake_stream_.AddReadPacketString( - {}, kAuthenticationSucceededResponse); - xmpp_client_.fake_stream_.ExpectWritePacketString({}, kStartStreamMessage); + xmpp_client_.AddReadPacketString({}, kAuthenticationSucceededResponse); + xmpp_client_.ExpectWritePacketString({}, kStartStreamMessage); RunUntil(XmppChannel::XmppState::kStreamRestartedPostAuthentication); } TEST_F(XmppChannelTest, HandleAuthenticationFailedResponse) { StartWithState(XmppChannel::XmppState::kAuthenticationStarted); - xmpp_client_.fake_stream_.AddReadPacketString({}, - kAuthenticationFailedResponse); + xmpp_client_.AddReadPacketString({}, kAuthenticationFailedResponse); RunUntil(XmppChannel::XmppState::kAuthenticationFailed); } TEST_F(XmppChannelTest, HandleStreamRestartedResponse) { StartWithState(XmppChannel::XmppState::kStreamRestartedPostAuthentication); - xmpp_client_.fake_stream_.AddReadPacketString({}, kRestartStreamResponse); - xmpp_client_.fake_stream_.ExpectWritePacketString({}, kBindMessage); + xmpp_client_.AddReadPacketString({}, kRestartStreamResponse); + xmpp_client_.ExpectWritePacketString({}, kBindMessage); RunUntil(XmppChannel::XmppState::kBindSent); EXPECT_TRUE(xmpp_client_.jid().empty()); - xmpp_client_.fake_stream_.AddReadPacketString({}, kBindResponse); - xmpp_client_.fake_stream_.ExpectWritePacketString({}, kSessionMessage); + xmpp_client_.AddReadPacketString({}, kBindResponse); + xmpp_client_.ExpectWritePacketString({}, kSessionMessage); RunUntil(XmppChannel::XmppState::kSessionStarted); EXPECT_EQ( "110cc78f78d7032cc7bf2c6e14c1fa7d@clouddevices.gserviceaccount.com" "/19853128", xmpp_client_.jid()); - xmpp_client_.fake_stream_.AddReadPacketString({}, kSessionResponse); - xmpp_client_.fake_stream_.ExpectWritePacketString({}, kSubscribeMessage); + xmpp_client_.AddReadPacketString({}, kSessionResponse); + xmpp_client_.ExpectWritePacketString({}, kSubscribeMessage); RunUntil(XmppChannel::XmppState::kSubscribeStarted); - xmpp_client_.fake_stream_.AddReadPacketString({}, kSubscribedResponse); + xmpp_client_.AddReadPacketString({}, kSubscribedResponse); RunUntil(XmppChannel::XmppState::kSubscribed); }