Switch from starttls to SSL sockets

SSL sockets simplifies both libweave and device side.
Also removed Flush and Close method, as not needed with current design.

BUG=b:23529805

Change-Id: I6e2a5522f8a55d20b0ed6363a9955d37505d2c06
diff --git a/libweave/examples/ubuntu/network_manager.cc b/libweave/examples/ubuntu/network_manager.cc
index 48c817c..ab1d521 100644
--- a/libweave/examples/ubuntu/network_manager.cc
+++ b/libweave/examples/ubuntu/network_manager.cc
@@ -39,160 +39,11 @@
   NOTREACHED();
 }
 
-class SocketStream : public Stream {
- public:
-  explicit SocketStream(TaskRunner* task_runner) : task_runner_{task_runner} {}
-
-  ~SocketStream() { CloseBlocking(nullptr); }
-
-  void RunDelayedTask(const base::Closure& success_callback) {
-    success_callback.Run();
-  }
-
-  bool ReadAsync(void* buffer,
-                 size_t size_to_read,
-                 const base::Callback<void(size_t)>& success_callback,
-                 const base::Callback<void(const Error*)>& error_callback,
-                 ErrorPtr* error) {
-    if (socket_fd_ < 0) {
-      Error::AddTo(error, FROM_HERE, "socket", "invalid_socket",
-                   strerror(errno));
-      return false;
-    }
-    int size_read = recv(socket_fd_, buffer, size_to_read, MSG_DONTWAIT);
-    if (size_read > 0) {
-      task_runner_->PostDelayedTask(
-          FROM_HERE, base::Bind(&SocketStream::RunDelayedTask,
-                                weak_ptr_factory_.GetWeakPtr(),
-                                base::Bind(success_callback, size_read)),
-          {});
-      return true;
-    }
-    if (errno == EAGAIN || errno == EWOULDBLOCK) {
-      task_runner_->PostDelayedTask(
-          FROM_HERE,
-          base::Bind(base::IgnoreResult(&SocketStream::ReadAsync),
-                     weak_ptr_factory_.GetWeakPtr(), buffer, size_to_read,
-                     success_callback, error_callback, nullptr),
-          base::TimeDelta::FromMilliseconds(200));
-      return true;
-    }
-
-    ErrorPtr recv_error;
-    Error::AddTo(&recv_error, FROM_HERE, "socket", "socket_recv_failed",
-                 strerror(errno));
-    task_runner_->PostDelayedTask(
-        FROM_HERE,
-        base::Bind(error_callback, base::Owned(recv_error.release())), {});
-    return true;
-  }
-
-  bool WriteAllAsync(const void* buffer,
-                     size_t size_to_write,
-                     const base::Closure& success_callback,
-                     const base::Callback<void(const Error*)>& error_callback,
-                     ErrorPtr* error) {
-    if (socket_fd_ < 0) {
-      Error::AddTo(error, FROM_HERE, "socket", "invalid_socket",
-                   strerror(errno));
-      return false;
-    }
-    const char* buffer_ptr = static_cast<const char*>(buffer);
-    do {
-      int size_sent = send(socket_fd_, buffer_ptr, size_to_write, 0);
-      if (size_sent <= 0) {
-        ErrorPtr send_error;
-        Error::AddTo(&send_error, FROM_HERE, "socket", "socket_send_failed",
-                     strerror(errno));
-        task_runner_->PostDelayedTask(
-            FROM_HERE,
-            base::Bind(error_callback, base::Owned(send_error.release())), {});
-        // Still true as we return error with callback.
-        return true;
-      }
-      size_to_write -= size_sent;
-      buffer_ptr += size_sent;
-    } while (size_to_write > 0);
-
-    task_runner_->PostDelayedTask(FROM_HERE, success_callback, {});
-    return true;
-  }
-  bool FlushBlocking(ErrorPtr* error) { return true; }
-
-  bool CloseBlocking(ErrorPtr* error) {
-    weak_ptr_factory_.InvalidateWeakPtrs();
-    if (socket_fd_ >= 0) {
-      close(socket_fd_);
-      socket_fd_ = -1;
-    }
-  }
-
-  void CancelPendingAsyncOperations() {
-    weak_ptr_factory_.InvalidateWeakPtrs();
-  }
-
-  bool Connect(const std::string& host, uint16_t port) {
-    std::string service = std::to_string(port);
-    addrinfo hints = {0, AF_UNSPEC, SOCK_STREAM};
-    addrinfo* result = nullptr;
-    if (getaddrinfo(host.c_str(), service.c_str(), &hints, &result)) {
-      LOG(ERROR) << "Failed to resolve host name: " << host;
-      return false;
-    }
-    std::unique_ptr<addrinfo, decltype(&freeaddrinfo)> result_deleter{
-        result, &freeaddrinfo};
-
-    for (const addrinfo* info = result; info != nullptr; info = info->ai_next) {
-      socket_fd_ =
-          socket(info->ai_family, info->ai_socktype, info->ai_protocol);
-      if (socket_fd_ < 0)
-        continue;
-
-      int flags = fcntl(socket_fd_, F_GETFL, 0);
-      if (flags == -1)
-        flags = 0;
-      fcntl(socket_fd_, F_SETFL, flags | O_NONBLOCK);
-
-      LOG(INFO) << "Connecting...";
-      if (connect(socket_fd_, info->ai_addr, info->ai_addrlen) == 0)
-        break;  // Success.
-
-      if (errno == EINPROGRESS) {
-        fd_set write_fds;
-        FD_ZERO(&write_fds);
-        FD_SET(socket_fd_, &write_fds);
-
-        struct timeval tv;
-        tv.tv_sec = 5;
-        tv.tv_usec = 0;
-
-        int select_ret = select(socket_fd_ + 1, NULL, &write_fds, NULL, &tv);
-        if (select_ret != -1 && select_ret != 0) {
-          break;
-        }
-      }
-
-      LOG(ERROR) << "Failed to connect";
-      CloseBlocking(nullptr);
-    }
-
-    return socket_fd_ >= 0;
-  }
-
-  int GetFd() const { return socket_fd_; }
-
- private:
-  TaskRunner* task_runner_{nullptr};
-  int socket_fd_{-1};
-
-  base::WeakPtrFactory<SocketStream> weak_ptr_factory_{this};
-};
-
 class SSLStream : public Stream {
  public:
   explicit SSLStream(TaskRunner* task_runner) : task_runner_{task_runner} {}
 
-  ~SSLStream() { weak_ptr_factory_.InvalidateWeakPtrs(); }
+  ~SSLStream() { CancelPendingAsyncOperations(); }
 
   void RunDelayedTask(const base::Closure& success_callback) {
     success_callback.Run();
@@ -289,33 +140,27 @@
     return true;
   }
 
-  bool FlushBlocking(ErrorPtr* error) { return true; }
-
-  bool CloseBlocking(ErrorPtr* error) {
-    weak_ptr_factory_.InvalidateWeakPtrs();
-    return true;
-  }
-
   void CancelPendingAsyncOperations() {
     weak_ptr_factory_.InvalidateWeakPtrs();
   }
 
-  bool Init() {
+  bool Init(const std::string& host, uint16_t port) {
     ctx_.reset(SSL_CTX_new(TLSv1_2_client_method()));
     CHECK(ctx_);
     ssl_.reset(SSL_new(ctx_.get()));
 
-    char endpoint[] = "talk.google.com:5223";
-    stream_bio_ = BIO_new_connect(endpoint);
-    CHECK(stream_bio_);
-    BIO_set_nbio(stream_bio_, 1);
+    char end_point[255];
+    snprintf(end_point, sizeof(end_point), "%s:%u", host.c_str(), port);
+    BIO* stream_bio = BIO_new_connect(end_point);
+    CHECK(stream_bio);
+    BIO_set_nbio(stream_bio, 1);
 
-    while (BIO_do_connect(stream_bio_) != 1) {
-      CHECK(BIO_should_retry(stream_bio_));
+    while (BIO_do_connect(stream_bio) != 1) {
+      CHECK(BIO_should_retry(stream_bio));
       sleep(1);
     }
 
-    SSL_set_bio(ssl_.get(), stream_bio_, stream_bio_);
+    SSL_set_bio(ssl_.get(), stream_bio, stream_bio);
     SSL_set_connect_state(ssl_.get());
 
     for (;;) {
@@ -339,7 +184,6 @@
   TaskRunner* task_runner_{nullptr};
   std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx_{nullptr, SSL_CTX_free};
   std::unique_ptr<SSL, decltype(&SSL_free)> ssl_{nullptr, SSL_free};
-  BIO* stream_bio_{nullptr};
 
   base::WeakPtrFactory<SSLStream> weak_ptr_factory_{this};
 };
@@ -504,23 +348,15 @@
     i.Run(online);
 }
 
-std::unique_ptr<Stream> NetworkImpl::OpenSocketBlocking(const std::string& host,
-                                                        uint16_t port) {
-  std::unique_ptr<SocketStream> stream{new SocketStream{task_runner_}};
-  if (!stream->Connect(host, port))
-    return nullptr;
-  return std::move(stream);
-}
-
-void NetworkImpl::CreateTlsStream(
-    std::unique_ptr<Stream> stream,
+void NetworkImpl::OpenSslSocket(
     const std::string& host,
+    uint16_t port,
     const base::Callback<void(std::unique_ptr<Stream>)>& success_callback,
     const base::Callback<void(const Error*)>& error_callback) {
   // Connect to SSL port instead of upgrading to TLS.
   std::unique_ptr<SSLStream> tls_stream{new SSLStream{task_runner_}};
 
-  if (tls_stream->Init()) {
+  if (tls_stream->Init(host, port)) {
     task_runner_->PostDelayedTask(
         FROM_HERE, base::Bind(success_callback, base::Passed(&tls_stream)), {});
   } else {
diff --git a/libweave/examples/ubuntu/network_manager.h b/libweave/examples/ubuntu/network_manager.h
index c58d220..c9d6ff4 100644
--- a/libweave/examples/ubuntu/network_manager.h
+++ b/libweave/examples/ubuntu/network_manager.h
@@ -34,11 +34,9 @@
   NetworkState GetConnectionState() const override;
   void EnableAccessPoint(const std::string& ssid) override;
   void DisableAccessPoint() override;
-  std::unique_ptr<Stream> OpenSocketBlocking(const std::string& host,
-                                             uint16_t port) override;
-  void CreateTlsStream(
-      std::unique_ptr<Stream> stream,
+  void OpenSslSocket(
       const std::string& host,
+      uint16_t port,
       const base::Callback<void(std::unique_ptr<Stream>)>& success_callback,
       const base::Callback<void(const Error*)>& error_callback) override;
 
diff --git a/libweave/include/weave/network.h b/libweave/include/weave/network.h
index 41e5413..0f0b12b 100644
--- a/libweave/include/weave/network.h
+++ b/libweave/include/weave/network.h
@@ -48,14 +48,9 @@
   virtual void DisableAccessPoint() = 0;
 
   // Opens bidirectional sockets and returns attached stream.
-  // TODO(vitalybuka): Make async.
-  virtual std::unique_ptr<Stream> OpenSocketBlocking(const std::string& host,
-                                                     uint16_t port) = 0;
-
-  // Replaces stream with version with TLS support.
-  virtual void CreateTlsStream(
-      std::unique_ptr<Stream> socket,
+  virtual void OpenSslSocket(
       const std::string& host,
+      uint16_t port,
       const base::Callback<void(std::unique_ptr<Stream>)>& success_callback,
       const base::Callback<void(const Error*)>& error_callback) = 0;
 
diff --git a/libweave/include/weave/stream.h b/libweave/include/weave/stream.h
index 3ecdba6..9d4e657 100644
--- a/libweave/include/weave/stream.h
+++ b/libweave/include/weave/stream.h
@@ -30,10 +30,6 @@
       const base::Callback<void(const Error*)>& error_callback,
       ErrorPtr* error) = 0;
 
-  virtual bool FlushBlocking(ErrorPtr* error) = 0;
-
-  virtual bool CloseBlocking(ErrorPtr* error) = 0;
-
   virtual void CancelPendingAsyncOperations() = 0;
 };
 
diff --git a/libweave/include/weave/test/mock_network.h b/libweave/include/weave/test/mock_network.h
index f09533e..f188223 100644
--- a/libweave/include/weave/test/mock_network.h
+++ b/libweave/include/weave/test/mock_network.h
@@ -30,22 +30,11 @@
   MOCK_METHOD1(EnableAccessPoint, void(const std::string&));
   MOCK_METHOD0(DisableAccessPoint, void());
 
-  MOCK_METHOD2(MockOpenSocketBlocking, Stream*(const std::string&, uint16_t));
-  MOCK_METHOD2(MockCreateTlsStream, Stream*(Stream*, const std::string&));
-
-  std::unique_ptr<Stream> OpenSocketBlocking(const std::string& host,
-                                             uint16_t port) override {
-    return std::unique_ptr<Stream>{MockOpenSocketBlocking(host, port)};
-  }
-
-  void CreateTlsStream(
-      std::unique_ptr<Stream> socket,
-      const std::string& host,
-      const base::Callback<void(std::unique_ptr<Stream>)>& success_callback,
-      const base::Callback<void(const Error*)>& error_callback) override {
-    success_callback.Run(
-        std::unique_ptr<Stream>{MockCreateTlsStream(socket.get(), host)});
-  }
+  MOCK_METHOD4(OpenSslSocket,
+               void(const std::string&,
+                    uint16_t,
+                    const base::Callback<void(std::unique_ptr<Stream>)>&,
+                    const base::Callback<void(const Error*)>&));
 };
 
 }  // namespace test
diff --git a/libweave/src/notification/xmpp_channel.cc b/libweave/src/notification/xmpp_channel.cc
index dc9f851..ad4e233 100644
--- a/libweave/src/notification/xmpp_channel.cc
+++ b/libweave/src/notification/xmpp_channel.cc
@@ -75,7 +75,7 @@
 };
 
 const char kDefaultXmppHost[] = "talk.google.com";
-const uint16_t kDefaultXmppPort = 5222;
+const uint16_t kDefaultXmppPort = 5223;
 
 // Used for keeping connection alive.
 const int kRegularPingIntervalSeconds = 60;
@@ -152,20 +152,6 @@
 
   switch (state_) {
     case XmppState::kConnected:
-      if (stanza->name() == "stream:features" &&
-          stanza->FindFirstChild("starttls/required", false)) {
-        state_ = XmppState::kTlsStarted;
-        SendMessage("<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>");
-        return;
-      }
-      break;
-    case XmppState::kTlsStarted:
-      if (stanza->name() == "proceed") {
-        StartTlsHandshake();
-        return;
-      }
-      break;
-    case XmppState::kTlsCompleted:
       if (stanza->name() == "stream:features") {
         auto children = stanza->FindChildren("mechanisms/mechanism", false);
         for (const auto& child : children) {
@@ -217,7 +203,7 @@
   }
   // Something bad happened. Close the stream and start over.
   LOG(ERROR) << "Error condition occurred handling stanza: "
-             << stanza->ToString();
+             << stanza->ToString() << " in state: " << static_cast<int>(state_);
   CloseStream();
 }
 
@@ -290,25 +276,37 @@
     ParseNotificationJson(*json_dict, delegate_);
 }
 
-void XmppChannel::StartTlsHandshake() {
-  raw_socket_->CancelPendingAsyncOperations();
-  network_->CreateTlsStream(
-      std::move(raw_socket_), host_,
-      base::Bind(&XmppChannel::OnTlsHandshakeComplete,
+void XmppChannel::CreateSslSocket() {
+  CHECK(!stream_);
+  state_ = XmppState::kConnecting;
+  LOG(INFO) << "Starting XMPP connection to " << kDefaultXmppHost << ":" << kDefaultXmppPort;
+
+  network_->OpenSslSocket(
+      kDefaultXmppHost, kDefaultXmppPort,
+      base::Bind(&XmppChannel::OnSslSocketReady,
                  task_ptr_factory_.GetWeakPtr()),
-      base::Bind(&XmppChannel::OnTlsError, task_ptr_factory_.GetWeakPtr()));
+      base::Bind(&XmppChannel::OnSslError, task_ptr_factory_.GetWeakPtr()));
 }
 
-void XmppChannel::OnTlsHandshakeComplete(std::unique_ptr<Stream> tls_stream) {
-  tls_stream_ = std::move(tls_stream);
-  stream_ = tls_stream_.get();
-  state_ = XmppState::kTlsCompleted;
+void XmppChannel::OnSslSocketReady(std::unique_ptr<Stream> stream) {
+  CHECK(XmppState::kConnecting == state_);
+  backoff_entry_.InformOfRequest(true);
+  stream_ = std::move(stream);
+  state_ = XmppState::kConnected;
   RestartXmppStream();
+  ScheduleRegularPing();
 }
 
-void XmppChannel::OnTlsError(const Error* error) {
+void XmppChannel::OnSslError(const Error* error) {
   LOG(ERROR) << "TLS handshake failed. Restarting XMPP connection";
-  Restart();
+  backoff_entry_.InformOfRequest(false);
+
+  LOG(INFO) << "Delaying connection to XMPP server for "
+            << backoff_entry_.GetTimeUntilRelease();
+  task_runner_->PostDelayedTask(
+      FROM_HERE,
+      base::Bind(&XmppChannel::CreateSslSocket, task_ptr_factory_.GetWeakPtr()),
+      backoff_entry_.GetTimeUntilRelease());
 }
 
 void XmppChannel::SendMessage(const std::string& message) {
@@ -336,10 +334,6 @@
 void XmppChannel::OnMessageSent() {
   ErrorPtr error;
   write_pending_ = false;
-  if (!stream_->FlushBlocking(&error)) {
-    OnWriteError(error.get());
-    return;
-  }
   if (queued_write_data_.empty()) {
     WaitForMessage();
   } else {
@@ -373,30 +367,6 @@
   Restart();
 }
 
-void XmppChannel::Connect(const std::string& host,
-                          uint16_t port,
-                          const base::Closure& callback) {
-  state_ = XmppState::kConnecting;
-  LOG(INFO) << "Starting XMPP connection to " << host << ":" << port;
-  raw_socket_ = network_->OpenSocketBlocking(host, port);
-
-  backoff_entry_.InformOfRequest(raw_socket_ != nullptr);
-  if (raw_socket_) {
-    host_ = host;
-    port_ = port;
-    stream_ = raw_socket_.get();
-    callback.Run();
-  } else {
-    LOG(INFO) << "Delaying connection to XMPP server " << host << " for "
-              << backoff_entry_.GetTimeUntilRelease();
-    task_runner_->PostDelayedTask(
-        FROM_HERE,
-        base::Bind(&XmppChannel::Connect, task_ptr_factory_.GetWeakPtr(), host,
-                   port, callback),
-        backoff_entry_.GetTimeUntilRelease());
-  }
-}
-
 std::string XmppChannel::GetName() const {
   return "xmpp";
 }
@@ -419,9 +389,7 @@
   CHECK(state_ == XmppState::kNotStarted);
   delegate_ = delegate;
 
-  Connect(
-      kDefaultXmppHost, kDefaultXmppPort,
-      base::Bind(&XmppChannel::OnConnected, task_ptr_factory_.GetWeakPtr()));
+  CreateSslSocket();
 }
 
 void XmppChannel::Stop() {
@@ -431,25 +399,10 @@
   task_ptr_factory_.InvalidateWeakPtrs();
   ping_ptr_factory_.InvalidateWeakPtrs();
 
-  if (tls_stream_) {
-    tls_stream_->CloseBlocking(nullptr);
-    tls_stream_.reset();
-  }
-  if (raw_socket_) {
-    raw_socket_->CloseBlocking(nullptr);
-    raw_socket_.reset();
-  }
-  stream_ = nullptr;
+  stream_.reset();
   state_ = XmppState::kNotStarted;
 }
 
-void XmppChannel::OnConnected() {
-  CHECK(XmppState::kConnecting == state_);
-  state_ = XmppState::kConnected;
-  RestartXmppStream();
-  ScheduleRegularPing();
-}
-
 void XmppChannel::RestartXmppStream() {
   stream_parser_.Reset();
   stream_->CancelPendingAsyncOperations();
diff --git a/libweave/src/notification/xmpp_channel.h b/libweave/src/notification/xmpp_channel.h
index 112799f..ddbdee1 100644
--- a/libweave/src/notification/xmpp_channel.h
+++ b/libweave/src/notification/xmpp_channel.h
@@ -61,8 +61,6 @@
     kNotStarted,
     kConnecting,
     kConnected,
-    kTlsStarted,
-    kTlsCompleted,
     kAuthenticationStarted,
     kAuthenticationFailed,
     kStreamRestartedPostAuthentication,
@@ -75,20 +73,13 @@
  protected:
   // These methods are internal helpers that can be overloaded by unit tests
   // to help provide unit-test-specific functionality.
-  virtual void Connect(const std::string& host,
-                       uint16_t port,
-                       const base::Closure& callback);
   virtual void SchedulePing(base::TimeDelta interval, base::TimeDelta timeout);
   void ScheduleRegularPing();
   void ScheduleFastPing();
 
-  XmppState state_{XmppState::kNotStarted};
-
-  // The connection socket stream to the XMPP server.
-  Stream* stream_{nullptr};
-
  private:
   friend class IqStanzaHandler;
+  friend class FakeXmppChannel;
 
   // Overrides from XmppStreamParser::Delegate.
   void OnStreamStart(const std::string& node_name,
@@ -103,13 +94,12 @@
   void HandleMessageStanza(std::unique_ptr<XmlNode> stanza);
   void RestartXmppStream();
 
-  void StartTlsHandshake();
-  void OnTlsHandshakeComplete(std::unique_ptr<Stream> tls_stream);
-  void OnTlsError(const Error* error);
+  void CreateSslSocket();
+  void OnSslSocketReady(std::unique_ptr<Stream> stream);
+  void OnSslError(const Error* error);
 
   void WaitForMessage();
 
-  void OnConnected();
   void OnMessageRead(size_t size);
   void OnMessageSent();
   void OnReadError(const Error* error);
@@ -130,6 +120,8 @@
 
   void OnConnectivityChanged(bool online);
 
+  XmppState state_{XmppState::kNotStarted};
+
   // Robot account name for the device.
   std::string account_;
 
@@ -140,8 +132,7 @@
   std::string access_token_;
 
   Network* network_{nullptr};
-  std::unique_ptr<Stream> raw_socket_;
-  std::unique_ptr<Stream> tls_stream_;  // Must follow |raw_socket_|.
+  std::unique_ptr<Stream> stream_;
 
   // Read buffer for incoming message packets.
   std::vector<char> read_socket_data_;
diff --git a/libweave/src/notification/xmpp_channel_unittest.cc b/libweave/src/notification/xmpp_channel_unittest.cc
index 574dfa2..d7c8174 100644
--- a/libweave/src/notification/xmpp_channel_unittest.cc
+++ b/libweave/src/notification/xmpp_channel_unittest.cc
@@ -8,11 +8,16 @@
 #include <queue>
 
 #include <gtest/gtest.h>
+#include <weave/test/mock_network.h>
 #include <weave/test/mock_task_runner.h>
 
 #include "libweave/src/bind_lambda.h"
 
+using testing::_;
+using testing::Invoke;
+using testing::Return;
 using testing::StrictMock;
+using testing::WithArgs;
 
 namespace weave {
 
@@ -21,16 +26,22 @@
 constexpr char kAccountName[] = "Account@Name";
 constexpr char kAccessToken[] = "AccessToken";
 
+constexpr char kStartStreamMessage[] =
+    "<stream:stream to='clouddevices.gserviceaccount.com' "
+    "xmlns:stream='http://etherx.jabber.org/streams' xml:lang='*' "
+    "version='1.0' xmlns='jabber:client'>";
 constexpr char kStartStreamResponse[] =
     "<stream:stream from=\"clouddevices.gserviceaccount.com\" "
     "id=\"0CCF520913ABA04B\" version=\"1.0\" "
     "xmlns:stream=\"http://etherx.jabber.org/streams\" "
-    "xmlns=\"jabber:client\">"
-    "<stream:features><starttls xmlns=\"urn:ietf:params:xml:ns:xmpp-tls\">"
-    "<required/></starttls><mechanisms "
-    "xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\"><mechanism>X-OAUTH2</mechanism>"
-    "<mechanism>X-GOOGLE-TOKEN</mechanism></mechanisms></stream:features>";
-constexpr char kTlsStreamResponse[] =
+    "xmlns=\"jabber:client\">";
+constexpr char kAuthenticationMessage[] =
+    "<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='X-OAUTH2' "
+    "auth:service='oauth2' auth:allow-non-google-login='true' "
+    "auth:client-uses-full-bind-result='true' "
+    "xmlns:auth='http://www.google.com/talk/protocol/auth'>"
+    "AEFjY291bnRATmFtZQBBY2Nlc3NUb2tlbg==</auth>";
+constexpr char kConnectedResponse[] =
     "<stream:features><mechanisms xmlns=\"urn:ietf:params:xml:ns:xmpp-sasl\">"
     "<mechanism>X-OAUTH2</mechanism>"
     "<mechanism>X-GOOGLE-TOKEN</mechanism></mechanisms></stream:features>";
@@ -55,18 +66,6 @@
     "19853128\" from=\""
     "110cc78f78d7032cc7bf2c6e14c1fa7d@clouddevices.gserviceaccount.com\" "
     "id=\"3\" type=\"result\"/>";
-constexpr char kStartStreamMessage[] =
-    "<stream:stream to='clouddevices.gserviceaccount.com' "
-    "xmlns:stream='http://etherx.jabber.org/streams' xml:lang='*' "
-    "version='1.0' xmlns='jabber:client'>";
-constexpr char kStartTlsMessage[] =
-    "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>";
-constexpr char kAuthenticationMessage[] =
-    "<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='X-OAUTH2' "
-    "auth:service='oauth2' auth:allow-non-google-login='true' "
-    "auth:client-uses-full-bind-result='true' "
-    "xmlns:auth='http://www.google.com/talk/protocol/auth'>"
-    "AEFjY291bnRATmFtZQBBY2Nlc3NUb2tlbg==</auth>";
 constexpr char kBindMessage[] =
     "<iq id='1' type='set'><bind "
     "xmlns='urn:ietf:params:xml:ns:xmpp-bind'/></iq>";
@@ -84,10 +83,6 @@
  public:
   explicit FakeStream(TaskRunner* task_runner) : task_runner_{task_runner} {}
 
-  bool FlushBlocking(ErrorPtr* error) override { return true; }
-
-  bool CloseBlocking(ErrorPtr* error) override { return true; }
-
   void CancelPendingAsyncOperations() override {}
 
   void ExpectWritePacketString(base::TimeDelta, const std::string& data) {
@@ -134,35 +129,55 @@
 
 class FakeXmppChannel : public XmppChannel {
  public:
-  explicit FakeXmppChannel(TaskRunner* task_runner)
-      : XmppChannel{kAccountName, kAccessToken, task_runner, nullptr},
-        fake_stream_{task_runner} {}
+  explicit FakeXmppChannel(TaskRunner* task_runner, weave::Network* network)
+      : XmppChannel{kAccountName, kAccessToken, task_runner, network},
+        stream_{new FakeStream{task_runner_}},
+        fake_stream_{stream_.get()} {}
+
+  void Connect(
+      const base::Callback<void(std::unique_ptr<weave::Stream>)>& callback) {
+    callback.Run(std::move(stream_));
+  }
 
   XmppState state() const { return state_; }
   void set_state(XmppState state) { state_ = state; }
 
-  void Connect(const std::string& host,
-               uint16_t port,
-               const base::Closure& callback) override {
-    set_state(XmppState::kConnecting);
-    stream_ = &fake_stream_;
-    callback.Run();
-  }
-
   void SchedulePing(base::TimeDelta interval,
                     base::TimeDelta timeout) override {}
 
-  FakeStream fake_stream_;
+  void ExpectWritePacketString(base::TimeDelta delta, const std::string& data) {
+    fake_stream_->ExpectWritePacketString(delta, data);
+  }
+
+  void AddReadPacketString(base::TimeDelta delta, const std::string& data) {
+    fake_stream_->AddReadPacketString(delta, data);
+  }
+
+  std::unique_ptr<FakeStream> stream_;
+  FakeStream* fake_stream_{nullptr};
+};
+
+class MockNetwork : public weave::test::MockNetwork {
+ public:
+  MockNetwork() {
+    EXPECT_CALL(*this, AddOnConnectionChangedCallback(_))
+        .WillRepeatedly(Return());
+  }
 };
 
 class XmppChannelTest : public ::testing::Test {
  protected:
+  XmppChannelTest() {
+    EXPECT_CALL(network_, OpenSslSocket("talk.google.com", 5223, _, _))
+        .WillOnce(
+            WithArgs<2>(Invoke(&xmpp_client_, &FakeXmppChannel::Connect)));
+  }
+
   void StartStream() {
-    xmpp_client_.fake_stream_.ExpectWritePacketString({}, kStartStreamMessage);
-    xmpp_client_.fake_stream_.AddReadPacketString({}, kStartStreamResponse);
-    xmpp_client_.fake_stream_.ExpectWritePacketString({}, kStartTlsMessage);
+    xmpp_client_.ExpectWritePacketString({}, kStartStreamMessage);
+    xmpp_client_.AddReadPacketString({}, kStartStreamResponse);
     xmpp_client_.Start(nullptr);
-    RunUntil(XmppChannel::XmppState::kTlsStarted);
+    RunUntil(XmppChannel::XmppState::kConnected);
   }
 
   void StartWithState(XmppChannel::XmppState state) {
@@ -177,12 +192,13 @@
   }
 
   StrictMock<test::MockTaskRunner> task_runner_;
-  FakeXmppChannel xmpp_client_{&task_runner_};
+  StrictMock<MockNetwork> network_;
+  FakeXmppChannel xmpp_client_{&task_runner_, &network_};
 };
 
 TEST_F(XmppChannelTest, StartStream) {
   EXPECT_EQ(XmppChannel::XmppState::kNotStarted, xmpp_client_.state());
-  xmpp_client_.fake_stream_.ExpectWritePacketString({}, kStartStreamMessage);
+  xmpp_client_.ExpectWritePacketString({}, kStartStreamMessage);
   xmpp_client_.Start(nullptr);
   RunUntil(XmppChannel::XmppState::kConnected);
 }
@@ -191,48 +207,46 @@
   StartStream();
 }
 
-TEST_F(XmppChannelTest, HandleTLSCompleted) {
-  StartWithState(XmppChannel::XmppState::kTlsCompleted);
-  xmpp_client_.fake_stream_.AddReadPacketString({}, kTlsStreamResponse);
-  xmpp_client_.fake_stream_.ExpectWritePacketString({}, kAuthenticationMessage);
+TEST_F(XmppChannelTest, HandleConnected) {
+  StartWithState(XmppChannel::XmppState::kConnected);
+  xmpp_client_.AddReadPacketString({}, kConnectedResponse);
+  xmpp_client_.ExpectWritePacketString({}, kAuthenticationMessage);
   RunUntil(XmppChannel::XmppState::kAuthenticationStarted);
 }
 
 TEST_F(XmppChannelTest, HandleAuthenticationSucceededResponse) {
   StartWithState(XmppChannel::XmppState::kAuthenticationStarted);
-  xmpp_client_.fake_stream_.AddReadPacketString(
-      {}, kAuthenticationSucceededResponse);
-  xmpp_client_.fake_stream_.ExpectWritePacketString({}, kStartStreamMessage);
+  xmpp_client_.AddReadPacketString({}, kAuthenticationSucceededResponse);
+  xmpp_client_.ExpectWritePacketString({}, kStartStreamMessage);
   RunUntil(XmppChannel::XmppState::kStreamRestartedPostAuthentication);
 }
 
 TEST_F(XmppChannelTest, HandleAuthenticationFailedResponse) {
   StartWithState(XmppChannel::XmppState::kAuthenticationStarted);
-  xmpp_client_.fake_stream_.AddReadPacketString({},
-                                                kAuthenticationFailedResponse);
+  xmpp_client_.AddReadPacketString({}, kAuthenticationFailedResponse);
   RunUntil(XmppChannel::XmppState::kAuthenticationFailed);
 }
 
 TEST_F(XmppChannelTest, HandleStreamRestartedResponse) {
   StartWithState(XmppChannel::XmppState::kStreamRestartedPostAuthentication);
-  xmpp_client_.fake_stream_.AddReadPacketString({}, kRestartStreamResponse);
-  xmpp_client_.fake_stream_.ExpectWritePacketString({}, kBindMessage);
+  xmpp_client_.AddReadPacketString({}, kRestartStreamResponse);
+  xmpp_client_.ExpectWritePacketString({}, kBindMessage);
   RunUntil(XmppChannel::XmppState::kBindSent);
   EXPECT_TRUE(xmpp_client_.jid().empty());
 
-  xmpp_client_.fake_stream_.AddReadPacketString({}, kBindResponse);
-  xmpp_client_.fake_stream_.ExpectWritePacketString({}, kSessionMessage);
+  xmpp_client_.AddReadPacketString({}, kBindResponse);
+  xmpp_client_.ExpectWritePacketString({}, kSessionMessage);
   RunUntil(XmppChannel::XmppState::kSessionStarted);
   EXPECT_EQ(
       "110cc78f78d7032cc7bf2c6e14c1fa7d@clouddevices.gserviceaccount.com"
       "/19853128",
       xmpp_client_.jid());
 
-  xmpp_client_.fake_stream_.AddReadPacketString({}, kSessionResponse);
-  xmpp_client_.fake_stream_.ExpectWritePacketString({}, kSubscribeMessage);
+  xmpp_client_.AddReadPacketString({}, kSessionResponse);
+  xmpp_client_.ExpectWritePacketString({}, kSubscribeMessage);
   RunUntil(XmppChannel::XmppState::kSubscribeStarted);
 
-  xmpp_client_.fake_stream_.AddReadPacketString({}, kSubscribedResponse);
+  xmpp_client_.AddReadPacketString({}, kSubscribedResponse);
   RunUntil(XmppChannel::XmppState::kSubscribed);
 }