buffet: Add TLS support to XMPP connection

Properly handle <starttls> XMPP request from the server by initiating
TLS handshake.

BUG=brillo:191
TEST=`FEATURES=test emerge-link buffet`
     Manually test on LINK DUT and inspecting logs to make sure TLS
     connection to XMPP server is established successfully.

Change-Id: I94d8b5eb9e29402fc3d662afcfdbf2e0c5ec1a02
Reviewed-on: https://chromium-review.googlesource.com/272263
Tested-by: Alex Vakulenko <avakulenko@chromium.org>
Reviewed-by: Vitaly Buka <vitalybuka@chromium.org>
Commit-Queue: Alex Vakulenko <avakulenko@chromium.org>
diff --git a/buffet/notification/xmpp_channel.cc b/buffet/notification/xmpp_channel.cc
index 4567408..3674184 100644
--- a/buffet/notification/xmpp_channel.cc
+++ b/buffet/notification/xmpp_channel.cc
@@ -10,6 +10,7 @@
 #include <chromeos/backoff_entry.h>
 #include <chromeos/data_encoding.h>
 #include <chromeos/streams/file_stream.h>
+#include <chromeos/streams/tls_stream.h>
 
 #include "buffet/notification/notification_delegate.h"
 #include "buffet/notification/xml_node.h"
@@ -137,9 +138,22 @@
 void XmppChannel::HandleStanza(std::unique_ptr<XmlNode> stanza) {
   VLOG(2) << "XMPP stanza received: " << stanza->ToString();
 
-  // TODO(nathanbullock): Need to add support for TLS (brillo:191).
   switch (state_) {
     case XmppState::kStarted:
+      if (stanza->name() == "stream:features" &&
+          stanza->FindFirstChild("starttls/required", false)) {
+        state_ = XmppState::kTlsStarted;
+        SendMessage("<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>");
+        return;
+      }
+      break;
+    case XmppState::kTlsStarted:
+      if (stanza->name() == "proceed") {
+        StartTlsHandshake();
+        return;
+      }
+      break;
+    case XmppState::kTlsCompleted:
       if (stanza->name() == "stream:features") {
         auto children = stanza->FindChildren("mechanisms/mechanism", false);
         for (const auto& child : children) {
@@ -208,6 +222,28 @@
   SendMessage("</stream:stream>");
 }
 
+void XmppChannel::StartTlsHandshake() {
+  stream_->CancelPendingAsyncOperations();
+  chromeos::TlsStream::Connect(
+      std::move(raw_socket_), host_,
+      base::Bind(&XmppChannel::OnTlsHandshakeComplete,
+                 weak_ptr_factory_.GetWeakPtr()),
+      base::Bind(&XmppChannel::OnTlsError,
+                 weak_ptr_factory_.GetWeakPtr()));
+}
+
+void XmppChannel::OnTlsHandshakeComplete(chromeos::StreamPtr tls_stream) {
+  tls_stream_ = std::move(tls_stream);
+  stream_ = tls_stream_.get();
+  state_ = XmppState::kTlsCompleted;
+  RestartXmppStream();
+}
+
+void XmppChannel::OnTlsError(const chromeos::Error* error) {
+  LOG(ERROR) << "TLS handshake failed. Restarting XMPP connection";
+  Restart();
+}
+
 void XmppChannel::SendMessage(const std::string& message) {
   if (write_pending_) {
     queued_write_data_ += message;
@@ -285,6 +321,8 @@
 
   backoff_entry_.InformOfRequest(raw_socket_ != nullptr);
   if (raw_socket_) {
+    host_ = host;
+    port_ = port;
     stream_ = raw_socket_.get();
     callback.Run();
   } else {
@@ -325,6 +363,10 @@
     delegate_->OnDisconnected();
 
   weak_ptr_factory_.InvalidateWeakPtrs();
+  if (tls_stream_) {
+    tls_stream_->CloseBlocking(nullptr);
+    tls_stream_.reset();
+  }
   if (raw_socket_) {
     raw_socket_->CloseBlocking(nullptr);
     raw_socket_.reset();
diff --git a/buffet/notification/xmpp_channel.h b/buffet/notification/xmpp_channel.h
index d6b1550..4cf540e 100644
--- a/buffet/notification/xmpp_channel.h
+++ b/buffet/notification/xmpp_channel.h
@@ -43,6 +43,8 @@
   enum class XmppState {
     kNotStarted,
     kStarted,
+    kTlsStarted,
+    kTlsCompleted,
     kAuthenticationStarted,
     kAuthenticationFailed,
     kStreamRestartedPostAuthentication,
@@ -71,6 +73,10 @@
   void HandleStanza(std::unique_ptr<XmlNode> stanza);
   void RestartXmppStream();
 
+  void StartTlsHandshake();
+  void OnTlsHandshakeComplete(chromeos::StreamPtr tls_stream);
+  void OnTlsError(const chromeos::Error* error);
+
   void SendMessage(const std::string& message);
   void WaitForMessage();
 
@@ -88,6 +94,7 @@
   std::string access_token_;
 
   chromeos::StreamPtr raw_socket_;
+  chromeos::StreamPtr tls_stream_;
 
   // Read buffer for incoming message packets.
   std::vector<char> read_socket_data_;
@@ -95,6 +102,10 @@
   std::string write_socket_data_;
   std::string queued_write_data_;
 
+  // XMPP server name and port used for connection.
+  std::string host_;
+  uint16_t port_{0};
+
   chromeos::BackoffEntry backoff_entry_;
   NotificationDelegate* delegate_{nullptr};
   scoped_refptr<base::TaskRunner> task_runner_;
diff --git a/buffet/notification/xmpp_channel_unittest.cc b/buffet/notification/xmpp_channel_unittest.cc
index 7dbfd7a..6e4f4b1 100644
--- a/buffet/notification/xmpp_channel_unittest.cc
+++ b/buffet/notification/xmpp_channel_unittest.cc
@@ -33,6 +33,10 @@
     "<required/></starttls><mechanisms "
     "xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\"><mechanism>X-OAUTH2</mechanism>"
     "<mechanism>X-GOOGLE-TOKEN</mechanism></mechanisms></stream:features>";
+constexpr char kTlsStreamResponse[] =
+    "<stream:features><mechanisms xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\">"
+    "<mechanism>X-OAUTH2</mechanism>"
+    "<mechanism>X-GOOGLE-TOKEN</mechanism></mechanisms></stream:features>";
 constexpr char kAuthenticationSucceededResponse[] =
     "<success xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\"/>";
 constexpr char kAuthenticationFailedResponse[] =
@@ -59,6 +63,8 @@
     "<stream:stream to='clouddevices.gserviceaccount.com' "
     "xmlns:stream='http://etherx.jabber.org/streams' xml:lang='*' "
     "version='1.0' xmlns='jabber:client'>";
+constexpr char kStartTlsMessage[] =
+    "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>";
 constexpr char kAuthenticationMessage[] =
     "<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='X-OAUTH2' "
     "auth:service='oauth2' auth:allow-non-google-login='true' "
@@ -130,8 +136,7 @@
   void StartStream() {
     xmpp_client_->fake_stream_.ExpectWritePacketString({}, kStartStreamMessage);
     xmpp_client_->fake_stream_.AddReadPacketString({}, kStartStreamResponse);
-    xmpp_client_->fake_stream_.ExpectWritePacketString({},
-                                                       kAuthenticationMessage);
+    xmpp_client_->fake_stream_.ExpectWritePacketString({}, kStartTlsMessage);
     xmpp_client_->Start(nullptr);
     RunTasks(4);
   }
@@ -166,6 +171,16 @@
 
 TEST_F(XmppChannelTest, HandleStartedResponse) {
   StartStream();
+  EXPECT_EQ(XmppChannel::XmppState::kTlsStarted, xmpp_client_->state());
+}
+
+TEST_F(XmppChannelTest, HandleTLSCompleted) {
+  StartWithState(XmppChannel::XmppState::kTlsCompleted);
+  xmpp_client_->fake_stream_.AddReadPacketString(
+      {}, kTlsStreamResponse);
+  xmpp_client_->fake_stream_.ExpectWritePacketString({},
+                                                     kAuthenticationMessage);
+  RunTasks(4);
   EXPECT_EQ(XmppChannel::XmppState::kAuthenticationStarted,
             xmpp_client_->state());
 }