Use single callback for replies to async operations Single callback simplifies copying of callbacks and makes control flow more obvious. BUG:24267885 Change-Id: I489e7158e2bb1adf8c9c3966a0859fa024a57db2 Reviewed-on: https://weave-review.googlesource.com/1302 Reviewed-by: Vitaly Buka <vitalybuka@google.com>
diff --git a/libweave/examples/ubuntu/curl_http_client.cc b/libweave/examples/ubuntu/curl_http_client.cc index c058a9b..653cbb3 100644 --- a/libweave/examples/ubuntu/curl_http_client.cc +++ b/libweave/examples/ubuntu/curl_http_client.cc
@@ -35,18 +35,11 @@ CurlHttpClient::CurlHttpClient(provider::TaskRunner* task_runner) : task_runner_{task_runner} {} -void CurlHttpClient::PostError(const ErrorCallback& error_callback, - ErrorPtr error) { - task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(error_callback, base::Passed(&error)), {}); -} - void CurlHttpClient::SendRequest(Method method, const std::string& url, const Headers& headers, const std::string& data, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) { + const SendRequestCallback& callback) { std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl{curl_easy_init(), &curl_easy_cleanup}; CHECK(curl); @@ -96,7 +89,8 @@ if (res != CURLE_OK) { Error::AddTo(&error, FROM_HERE, "curl", "curl_easy_perform_error", curl_easy_strerror(res)); - return PostError(error_callback, std::move(error)); + return task_runner_->PostDelayedTask( + FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {}); } const std::string kContentType = "\r\nContent-Type:"; @@ -104,7 +98,8 @@ if (pos == std::string::npos) { Error::AddTo(&error, FROM_HERE, "curl", "no_content_header", "Content-Type header is missing"); - return PostError(error_callback, std::move(error)); + return task_runner_->PostDelayedTask( + FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {}); } pos += kContentType.size(); auto pos_end = response->content_type.find("\r\n", pos); @@ -118,15 +113,7 @@ &response->status)); task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(&CurlHttpClient::RunSuccessCallback, - weak_ptr_factory_.GetWeakPtr(), success_callback, - base::Passed(&response)), - {}); -} - -void CurlHttpClient::RunSuccessCallback(const SuccessCallback& success_callback, - std::unique_ptr<Response> response) { - success_callback.Run(*response); + FROM_HERE, base::Bind(callback, base::Passed(&response), nullptr), {}); } } // namespace examples
diff --git a/libweave/examples/ubuntu/curl_http_client.h b/libweave/examples/ubuntu/curl_http_client.h index a5090ff..6bbb3f4 100644 --- a/libweave/examples/ubuntu/curl_http_client.h +++ b/libweave/examples/ubuntu/curl_http_client.h
@@ -28,14 +28,9 @@ const std::string& url, const Headers& headers, const std::string& data, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) override; + const SendRequestCallback& callback) override; private: - void PostError(const ErrorCallback& error_callback, ErrorPtr error); - - void RunSuccessCallback(const SuccessCallback& success_callback, - std::unique_ptr<Response> response); provider::TaskRunner* task_runner_{nullptr}; base::WeakPtrFactory<CurlHttpClient> weak_ptr_factory_{this};
diff --git a/libweave/examples/ubuntu/event_http_client.cc b/libweave/examples/ubuntu/event_http_client.cc index 8341a07..637ae2f 100644 --- a/libweave/examples/ubuntu/event_http_client.cc +++ b/libweave/examples/ubuntu/event_http_client.cc
@@ -60,8 +60,7 @@ TaskRunner* task_runner_; std::unique_ptr<evhttp_uri, EventDeleter> http_uri_; std::unique_ptr<evhttp_connection, EventDeleter> evcon_; - HttpClient::SuccessCallback success_callback_; - ErrorCallback error_callback_; + HttpClient::SendRequestCallback callback_; }; void RequestDoneCallback(evhttp_request* req, void* ctx) { @@ -74,7 +73,7 @@ "request failed: %s", evutil_socket_error_to_string(err)); state->task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(state->error_callback_, base::Passed(&error)), + FROM_HERE, base::Bind(state->callback_, nullptr, base::Passed(&error)), {}); return; } @@ -86,7 +85,8 @@ auto n = evbuffer_remove(buffer, &response->data[0], length); CHECK_EQ(n, int(length)); state->task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(state->success_callback_, *response), {}); + FROM_HERE, base::Bind(state->callback_, base::Passed(&response), nullptr), + {}); } } // namespace @@ -98,8 +98,7 @@ const std::string& url, const Headers& headers, const std::string& data, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) { + const SendRequestCallback& callback) { evhttp_cmd_type method_id; CHECK(weave::StringToEnum(weave::EnumToString(method), &method_id)); std::unique_ptr<evhttp_uri, EventDeleter> http_uri{ @@ -129,7 +128,7 @@ std::unique_ptr<evhttp_request, EventDeleter> req{evhttp_request_new( &RequestDoneCallback, new EventRequestState{task_runner_, std::move(http_uri), std::move(conn), - success_callback, error_callback})}; + callback})}; CHECK(req); auto output_headers = evhttp_request_get_output_headers(req.get()); evhttp_add_header(output_headers, "Host", host); @@ -150,7 +149,7 @@ "request failed: %s %s", EnumToString(method).c_str(), url.c_str()); task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(error_callback, base::Passed(&error)), {}); + FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {}); } } // namespace examples
diff --git a/libweave/examples/ubuntu/event_http_client.h b/libweave/examples/ubuntu/event_http_client.h index 457e550..e564e8a 100644 --- a/libweave/examples/ubuntu/event_http_client.h +++ b/libweave/examples/ubuntu/event_http_client.h
@@ -24,8 +24,7 @@ const std::string& url, const Headers& headers, const std::string& data, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) override; + const SendRequestCallback& callback) override; private: EventTaskRunner* task_runner_{nullptr};
diff --git a/libweave/examples/ubuntu/main.cc b/libweave/examples/ubuntu/main.cc index e762791..6324411 100644 --- a/libweave/examples/ubuntu/main.cc +++ b/libweave/examples/ubuntu/main.cc
@@ -221,12 +221,11 @@ base::WeakPtrFactory<CommandHandler> weak_ptr_factory_{this}; }; -void RegisterDeviceSuccess(weave::Device* device) { - LOG(INFO) << "Device registered: " << device->GetSettings().cloud_id; -} - -void RegisterDeviceError(weave::ErrorPtr error) { - LOG(ERROR) << "Fail to register device: " << error->GetMessage(); +void OnRegisterDeviceDone(weave::Device* device, weave::ErrorPtr error) { + if (error) + LOG(ERROR) << "Fail to register device: " << error->GetMessage(); + else + LOG(INFO) << "Device registered: " << device->GetSettings().cloud_id; } } // namespace @@ -280,8 +279,7 @@ if (!registration_ticket.empty()) { device->Register(registration_ticket, - base::Bind(&RegisterDeviceSuccess, device.get()), - base::Bind(&RegisterDeviceError)); + base::Bind(&OnRegisterDeviceDone, device.get())); } CommandHandler handler(device.get(), &task_runner);
diff --git a/libweave/examples/ubuntu/netlink_network.cc b/libweave/examples/ubuntu/netlink_network.cc index 24d0996..60807a4 100644 --- a/libweave/examples/ubuntu/netlink_network.cc +++ b/libweave/examples/ubuntu/netlink_network.cc
@@ -91,23 +91,22 @@ return network_state_; } -void NetlinkNetworkImpl::OpenSslSocket( - const std::string& host, - uint16_t port, - const OpenSslSocketSuccessCallback& success_callback, - const ErrorCallback& error_callback) { +void NetlinkNetworkImpl::OpenSslSocket(const std::string& host, + uint16_t port, + const OpenSslSocketCallback& callback) { // Connect to SSL port instead of upgrading to TLS. std::unique_ptr<SSLStream> tls_stream{new SSLStream{task_runner_}}; if (tls_stream->Init(host, port)) { task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(success_callback, base::Passed(&tls_stream)), {}); + FROM_HERE, base::Bind(callback, base::Passed(&tls_stream), nullptr), + {}); } else { ErrorPtr error; Error::AddTo(&error, FROM_HERE, "tls", "tls_init_failed", "Failed to initialize TLS stream."); task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(error_callback, base::Passed(&error)), {}); + FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {}); } }
diff --git a/libweave/examples/ubuntu/netlink_network.h b/libweave/examples/ubuntu/netlink_network.h index 3f5f51a..af8bf13 100644 --- a/libweave/examples/ubuntu/netlink_network.h +++ b/libweave/examples/ubuntu/netlink_network.h
@@ -29,8 +29,7 @@ State GetConnectionState() const override; void OpenSslSocket(const std::string& host, uint16_t port, - const OpenSslSocketSuccessCallback& success_callback, - const ErrorCallback& error_callback) override; + const OpenSslSocketCallback& callback) override; private: class Deleter {
diff --git a/libweave/examples/ubuntu/network_manager.cc b/libweave/examples/ubuntu/network_manager.cc index ff54f0b..e6dd9d8 100644 --- a/libweave/examples/ubuntu/network_manager.cc +++ b/libweave/examples/ubuntu/network_manager.cc
@@ -61,8 +61,7 @@ const std::string& passphrase, int pid, base::Time until, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) { + const DoneCallback& callback) { if (pid) { int status = 0; if (pid == waitpid(pid, &status, WNOWAIT)) { @@ -79,7 +78,8 @@ close(sockf_d); if (ssid == essid) - return task_runner_->PostDelayedTask(FROM_HERE, success_callback, {}); + return task_runner_->PostDelayedTask(FROM_HERE, + base::Bind(callback, nullptr), {}); pid = 0; // Try again. } } @@ -94,34 +94,32 @@ Error::AddTo(&error, FROM_HERE, "wifi", "timeout", "Timeout connecting to WiFI network."); task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(error_callback, base::Passed(&error)), {}); + FROM_HERE, base::Bind(callback, base::Passed(&error)), {}); return; } task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(&NetworkImpl::TryToConnect, - weak_ptr_factory_.GetWeakPtr(), ssid, passphrase, - pid, until, success_callback, error_callback), + FROM_HERE, + base::Bind(&NetworkImpl::TryToConnect, weak_ptr_factory_.GetWeakPtr(), + ssid, passphrase, pid, until, callback), base::TimeDelta::FromSeconds(1)); } void NetworkImpl::Connect(const std::string& ssid, const std::string& passphrase, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) { + const DoneCallback& callback) { force_bootstrapping_ = false; CHECK(!hostapd_started_); if (hostapd_started_) { ErrorPtr error; Error::AddTo(&error, FROM_HERE, "wifi", "busy", "Running Access Point."); task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(error_callback, base::Passed(&error)), {}); + FROM_HERE, base::Bind(callback, base::Passed(&error)), {}); return; } TryToConnect(ssid, passphrase, 0, - base::Time::Now() + base::TimeDelta::FromMinutes(1), - success_callback, error_callback); + base::Time::Now() + base::TimeDelta::FromMinutes(1), callback); } void NetworkImpl::UpdateNetworkState() { @@ -200,23 +198,22 @@ return std::system("nmcli dev | grep ^wlan0") == 0; } -void NetworkImpl::OpenSslSocket( - const std::string& host, - uint16_t port, - const OpenSslSocketSuccessCallback& success_callback, - const ErrorCallback& error_callback) { +void NetworkImpl::OpenSslSocket(const std::string& host, + uint16_t port, + const OpenSslSocketCallback& callback) { // Connect to SSL port instead of upgrading to TLS. std::unique_ptr<SSLStream> tls_stream{new SSLStream{task_runner_}}; if (tls_stream->Init(host, port)) { task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(success_callback, base::Passed(&tls_stream)), {}); + FROM_HERE, base::Bind(callback, base::Passed(&tls_stream), nullptr), + {}); } else { ErrorPtr error; Error::AddTo(&error, FROM_HERE, "tls", "tls_init_failed", "Failed to initialize TLS stream."); task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(error_callback, base::Passed(&error)), {}); + FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {}); } }
diff --git a/libweave/examples/ubuntu/network_manager.h b/libweave/examples/ubuntu/network_manager.h index b8e589d..164d419 100644 --- a/libweave/examples/ubuntu/network_manager.h +++ b/libweave/examples/ubuntu/network_manager.h
@@ -35,14 +35,12 @@ State GetConnectionState() const override; void OpenSslSocket(const std::string& host, uint16_t port, - const OpenSslSocketSuccessCallback& success_callback, - const ErrorCallback& error_callback) override; + const OpenSslSocketCallback& callback) override; // Wifi implementation. void Connect(const std::string& ssid, const std::string& passphrase, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) override; + const DoneCallback& callback) override; void StartAccessPoint(const std::string& ssid) override; void StopAccessPoint() override; @@ -53,8 +51,7 @@ const std::string& passphrase, int pid, base::Time until, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback); + const DoneCallback& callback); void UpdateNetworkState(); bool force_bootstrapping_{false};
diff --git a/libweave/examples/ubuntu/ssl_stream.cc b/libweave/examples/ubuntu/ssl_stream.cc index e146d58..8b17358 100644 --- a/libweave/examples/ubuntu/ssl_stream.cc +++ b/libweave/examples/ubuntu/ssl_stream.cc
@@ -23,14 +23,13 @@ void SSLStream::Read(void* buffer, size_t size_to_read, - const ReadSuccessCallback& success_callback, - const ErrorCallback& error_callback) { + const ReadCallback& callback) { int res = SSL_read(ssl_.get(), buffer, size_to_read); if (res > 0) { task_runner_->PostDelayedTask( FROM_HERE, base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(), - base::Bind(success_callback, res)), + base::Bind(callback, res, nullptr)), {}); return; } @@ -39,9 +38,8 @@ if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { task_runner_->PostDelayedTask( - FROM_HERE, - base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(), buffer, - size_to_read, success_callback, error_callback), + FROM_HERE, base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(), + buffer, size_to_read, callback), base::TimeDelta::FromSeconds(1)); return; } @@ -52,15 +50,14 @@ task_runner_->PostDelayedTask( FROM_HERE, base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(), - base::Bind(error_callback, base::Passed(&weave_error))), + base::Bind(callback, 0, base::Passed(&weave_error))), {}); return; } void SSLStream::Write(const void* buffer, size_t size_to_write, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) { + const WriteCallback& callback) { int res = SSL_write(ssl_.get(), buffer, size_to_write); if (res > 0) { buffer = static_cast<const char*>(buffer) + res; @@ -69,15 +66,14 @@ task_runner_->PostDelayedTask( FROM_HERE, base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(), - success_callback), + base::Bind(callback, nullptr)), {}); return; } task_runner_->PostDelayedTask( - FROM_HERE, - base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer, - size_to_write, success_callback, error_callback), + FROM_HERE, base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), + buffer, size_to_write, callback), base::TimeDelta::FromSeconds(1)); return; @@ -87,9 +83,8 @@ if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { task_runner_->PostDelayedTask( - FROM_HERE, - base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer, - size_to_write, success_callback, error_callback), + FROM_HERE, base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), + buffer, size_to_write, callback), base::TimeDelta::FromSeconds(1)); return; } @@ -100,7 +95,7 @@ task_runner_->PostDelayedTask( FROM_HERE, base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(), - base::Bind(error_callback, base::Passed(&weave_error))), + base::Bind(callback, base::Passed(&weave_error))), {}); return; }
diff --git a/libweave/examples/ubuntu/ssl_stream.h b/libweave/examples/ubuntu/ssl_stream.h index ac0d76a..859cbf9 100644 --- a/libweave/examples/ubuntu/ssl_stream.h +++ b/libweave/examples/ubuntu/ssl_stream.h
@@ -26,13 +26,11 @@ void Read(void* buffer, size_t size_to_read, - const ReadSuccessCallback& success_callback, - const ErrorCallback& error_callback) override; + const ReadCallback& callback) override; void Write(const void* buffer, size_t size_to_write, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) override; + const WriteCallback& callback) override; void CancelPendingOperations() override;
diff --git a/libweave/include/weave/device.h b/libweave/include/weave/device.h index 7fbe248..09dab91 100644 --- a/libweave/include/weave/device.h +++ b/libweave/include/weave/device.h
@@ -121,8 +121,7 @@ // Registers the device. // This is testing method and should not be used by applications. virtual void Register(const std::string& ticket_id, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) = 0; + const DoneCallback& callback) = 0; // Handler should display pin code to the user. using PairingBeginCallback =
diff --git a/libweave/include/weave/error.h b/libweave/include/weave/error.h index 83b33b4..0869b71 100644 --- a/libweave/include/weave/error.h +++ b/libweave/include/weave/error.h
@@ -126,8 +126,13 @@ DISALLOW_COPY_AND_ASSIGN(Error); }; -using SuccessCallback = base::Closure; -using ErrorCallback = base::Callback<void(ErrorPtr error)>; +// Default callback type for async operations. +// Function having this callback as argument should call the callback exactly +// one time. +// Successfully completed operation should run callback with |error| set to +// null. Failed operation should run callback with |error| containing error +// details. +using DoneCallback = base::Callback<void(ErrorPtr error)>; } // namespace weave
diff --git a/libweave/include/weave/provider/http_client.h b/libweave/include/weave/provider/http_client.h index 24b4588..3c442d9 100644 --- a/libweave/include/weave/provider/http_client.h +++ b/libweave/include/weave/provider/http_client.h
@@ -34,14 +34,14 @@ }; using Headers = std::vector<std::pair<std::string, std::string>>; - using SuccessCallback = base::Callback<void(const Response&)>; + using SendRequestCallback = + base::Callback<void(std::unique_ptr<Response> response, ErrorPtr error)>; virtual void SendRequest(Method method, const std::string& url, const Headers& headers, const std::string& data, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) = 0; + const SendRequestCallback& callback) = 0; protected: virtual ~HttpClient() = default;
diff --git a/libweave/include/weave/provider/network.h b/libweave/include/weave/provider/network.h index cb3730e..4c0d5dd 100644 --- a/libweave/include/weave/provider/network.h +++ b/libweave/include/weave/provider/network.h
@@ -29,8 +29,8 @@ using ConnectionChangedCallback = base::Closure; // Callback type for OpenSslSocket. - using OpenSslSocketSuccessCallback = - base::Callback<void(std::unique_ptr<Stream> stream)>; + using OpenSslSocketCallback = + base::Callback<void(std::unique_ptr<Stream> stream, ErrorPtr error)>; // Subscribes to notification about changes in network connectivity. Changes // may include but not limited: interface up or down, new IP was assigned, @@ -42,11 +42,9 @@ virtual State GetConnectionState() const = 0; // Opens bidirectional sockets and returns attached stream. - virtual void OpenSslSocket( - const std::string& host, - uint16_t port, - const OpenSslSocketSuccessCallback& success_callback, - const ErrorCallback& error_callback) = 0; + virtual void OpenSslSocket(const std::string& host, + uint16_t port, + const OpenSslSocketCallback& callback) = 0; protected: virtual ~Network() = default;
diff --git a/libweave/include/weave/provider/test/mock_http_client.h b/libweave/include/weave/provider/test/mock_http_client.h index 72cb9b8..85ac154 100644 --- a/libweave/include/weave/provider/test/mock_http_client.h +++ b/libweave/include/weave/provider/test/mock_http_client.h
@@ -27,13 +27,12 @@ public: ~MockHttpClient() override = default; - MOCK_METHOD6(SendRequest, + MOCK_METHOD5(SendRequest, void(Method, const std::string&, const Headers&, const std::string&, - const SuccessCallback&, - const ErrorCallback&)); + const SendRequestCallback&)); }; } // namespace test
diff --git a/libweave/include/weave/provider/test/mock_network.h b/libweave/include/weave/provider/test/mock_network.h index e38dde2..b2811dc 100644 --- a/libweave/include/weave/provider/test/mock_network.h +++ b/libweave/include/weave/provider/test/mock_network.h
@@ -20,11 +20,10 @@ MOCK_METHOD1(AddConnectionChangedCallback, void(const ConnectionChangedCallback&)); MOCK_CONST_METHOD0(GetConnectionState, State()); - MOCK_METHOD4(OpenSslSocket, + MOCK_METHOD3(OpenSslSocket, void(const std::string&, uint16_t, - const OpenSslSocketSuccessCallback&, - const ErrorCallback&)); + const OpenSslSocketCallback&)); }; } // namespace test
diff --git a/libweave/include/weave/provider/test/mock_wifi.h b/libweave/include/weave/provider/test/mock_wifi.h index 6c53d9a..5798872 100644 --- a/libweave/include/weave/provider/test/mock_wifi.h +++ b/libweave/include/weave/provider/test/mock_wifi.h
@@ -17,11 +17,10 @@ class MockWifi : public Wifi { public: - MOCK_METHOD4(Connect, + MOCK_METHOD3(Connect, void(const std::string&, const std::string&, - const base::Closure&, - const ErrorCallback&)); + const DoneCallback&)); MOCK_METHOD1(StartAccessPoint, void(const std::string&)); MOCK_METHOD0(StopAccessPoint, void()); };
diff --git a/libweave/include/weave/provider/wifi.h b/libweave/include/weave/provider/wifi.h index 51f370c..aaf4609 100644 --- a/libweave/include/weave/provider/wifi.h +++ b/libweave/include/weave/provider/wifi.h
@@ -20,8 +20,7 @@ // should post either of callbacks. virtual void Connect(const std::string& ssid, const std::string& passphrase, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) = 0; + const DoneCallback& callback) = 0; // Starts WiFi access point for wifi setup. virtual void StartAccessPoint(const std::string& ssid) = 0;
diff --git a/libweave/include/weave/stream.h b/libweave/include/weave/stream.h index 9d4d6fd..ea8af23 100644 --- a/libweave/include/weave/stream.h +++ b/libweave/include/weave/stream.h
@@ -18,15 +18,14 @@ virtual ~InputStream() = default; // Callback type for Read. - using ReadSuccessCallback = base::Callback<void(size_t size)>; + using ReadCallback = base::Callback<void(size_t size, ErrorPtr error)>; - // Implementation should return immediately and post either success_callback - // or error_callback. Caller guarantees that buffet is alive until either of - // callback is called. + // Implementation should return immediately and post callback after + // completing operation. Caller guarantees that buffet is alive until callback + // is called. virtual void Read(void* buffer, size_t size_to_read, - const ReadSuccessCallback& success_callback, - const ErrorCallback& error_callback) = 0; + const ReadCallback& callback) = 0; }; // Interface for async input streaming. @@ -34,14 +33,15 @@ public: virtual ~OutputStream() = default; - // Implementation should return immediately and post either success_callback - // or error_callback. Caller guarantees that buffet is alive until either of - // callback is called. + using WriteCallback = base::Callback<void(ErrorPtr error)>; + + // Implementation should return immediately and post callback after + // completing operation. Caller guarantees that buffet is alive until either + // of callback is called. // Success callback must be called only after all data is written. virtual void Write(const void* buffer, size_t size_to_write, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) = 0; + const WriteCallback& callback) = 0; }; // Interface for async bi-directional streaming.
diff --git a/libweave/include/weave/test/fake_stream.h b/libweave/include/weave/test/fake_stream.h index 8abb491..e5e0b57 100644 --- a/libweave/include/weave/test/fake_stream.h +++ b/libweave/include/weave/test/fake_stream.h
@@ -31,12 +31,10 @@ void CancelPendingOperations() override; void Read(void* buffer, size_t size_to_read, - const ReadSuccessCallback& success_callback, - const ErrorCallback& error_callback) override; + const ReadCallback& callback) override; void Write(const void* buffer, size_t size_to_write, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) override; + const WriteCallback& callback) override; private: provider::TaskRunner* task_runner_{nullptr};
diff --git a/libweave/include/weave/test/mock_device.h b/libweave/include/weave/test/mock_device.h index 0d8a6da..c0a75bc 100644 --- a/libweave/include/weave/test/mock_device.h +++ b/libweave/include/weave/test/mock_device.h
@@ -44,10 +44,9 @@ MOCK_CONST_METHOD0(GetGcdState, GcdState()); MOCK_METHOD1(AddGcdStateChangedCallback, void(const GcdStateChangedCallback& callback)); - MOCK_METHOD3(Register, + MOCK_METHOD2(Register, void(const std::string& ticket_id, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback)); + const DoneCallback& callback)); MOCK_METHOD2(AddPairingChangedCallbacks, void(const PairingBeginCallback& begin_callback, const PairingEndCallback& end_callback));
diff --git a/libweave/src/commands/cloud_command_proxy.cc b/libweave/src/commands/cloud_command_proxy.cc index 3826378..41dc8ee 100644 --- a/libweave/src/commands/cloud_command_proxy.cc +++ b/libweave/src/commands/cloud_command_proxy.cc
@@ -139,10 +139,8 @@ command_update_in_progress_ = true; cloud_command_updater_->UpdateCommand( command_instance_->GetID(), *update_queue_.front().second, - base::Bind(&CloudCommandProxy::OnUpdateCommandFinished, - weak_ptr_factory_.GetWeakPtr(), true), - base::Bind(&CloudCommandProxy::OnUpdateCommandFinished, - weak_ptr_factory_.GetWeakPtr(), false)); + base::Bind(&CloudCommandProxy::OnUpdateCommandDone, + weak_ptr_factory_.GetWeakPtr())); } void CloudCommandProxy::ResendCommandUpdate() { @@ -150,10 +148,10 @@ SendCommandUpdate(); } -void CloudCommandProxy::OnUpdateCommandFinished(bool success) { +void CloudCommandProxy::OnUpdateCommandDone(ErrorPtr error) { command_update_in_progress_ = false; - cloud_backoff_entry_->InformOfRequest(success); - if (success) { + cloud_backoff_entry_->InformOfRequest(!error); + if (!error) { // Remove the succeeded update from the queue. update_queue_.pop_front(); }
diff --git a/libweave/src/commands/cloud_command_proxy.h b/libweave/src/commands/cloud_command_proxy.h index 86dc5d3..b2ef11a 100644 --- a/libweave/src/commands/cloud_command_proxy.h +++ b/libweave/src/commands/cloud_command_proxy.h
@@ -63,9 +63,7 @@ void ResendCommandUpdate(); // Callback invoked by the asynchronous PATCH request to the server. - // Called both in a case of successfully updating server command resource - // and in case of an error, indicated by the |success| parameter. - void OnUpdateCommandFinished(bool success); + void OnUpdateCommandDone(ErrorPtr error); // Callback invoked by the device state change queue to notify of the // successful device state update. |update_id| is the ID of the state that
diff --git a/libweave/src/commands/cloud_command_proxy_unittest.cc b/libweave/src/commands/cloud_command_proxy_unittest.cc index 3c7692f..d9518e4 100644 --- a/libweave/src/commands/cloud_command_proxy_unittest.cc +++ b/libweave/src/commands/cloud_command_proxy_unittest.cc
@@ -38,11 +38,10 @@ class MockCloudCommandUpdateInterface : public CloudCommandUpdateInterface { public: - MOCK_METHOD4(UpdateCommand, + MOCK_METHOD3(UpdateCommand, void(const std::string&, const base::DictionaryValue&, - const base::Closure&, - const base::Closure&)); + const DoneCallback&)); }; // Test back-off entry that uses the test clock. @@ -147,7 +146,7 @@ TEST_F(CloudCommandProxyTest, ImmediateUpdate) { const char expected[] = "{'state':'done'}"; - EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _, _)); + EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _)); command_instance_->Complete({}, nullptr); task_runner_.RunOnce(); } @@ -161,7 +160,7 @@ callbacks_.Notify(19); // Now we should get the update... const char expected[] = "{'state':'done'}"; - EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _, _)); + EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _)); callbacks_.Notify(20); } @@ -170,14 +169,14 @@ // state=inProgress // progress={...} // The first state update is sent immediately, the second should be delayed. - base::Closure on_success; + DoneCallback callback; EXPECT_CALL( cloud_updater_, UpdateCommand( kCmdID, - MatchJson("{'state':'inProgress', 'progress':{'status':'ready'}}"), _, + MatchJson("{'state':'inProgress', 'progress':{'status':'ready'}}"), _)) - .WillOnce(SaveArg<2>(&on_success)); + .WillOnce(SaveArg<2>(&callback)); EXPECT_TRUE(command_instance_->SetProgress( *CreateDictionaryValue("{'status': 'ready'}"), nullptr)); @@ -200,34 +199,35 @@ 'progress': {'status':'ready'}, 'state':'inProgress' })"; - EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _, _)); + EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _)); callbacks_.Notify(20); } TEST_F(CloudCommandProxyTest, RetryFailed) { - base::Closure on_success; - base::Closure on_error; + DoneCallback callback; const char expect[] = "{'state':'inProgress', 'progress': {'status': 'ready'}}"; - EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect), _, _)) + EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect), _)) .Times(3) - .WillRepeatedly(DoAll(SaveArg<2>(&on_success), SaveArg<3>(&on_error))); + .WillRepeatedly(SaveArg<2>(&callback)); auto started = task_runner_.GetClock()->Now(); EXPECT_TRUE(command_instance_->SetProgress( *CreateDictionaryValue("{'status': 'ready'}"), nullptr)); task_runner_.Run(); - on_error.Run(); + ErrorPtr error; + Error::AddTo(&error, FROM_HERE, "TEST", "TEST", "TEST"); + callback.Run(error->Clone()); task_runner_.Run(); EXPECT_GE(task_runner_.GetClock()->Now() - started, base::TimeDelta::FromSecondsD(0.9)); - on_error.Run(); + callback.Run(error->Clone()); task_runner_.Run(); EXPECT_GE(task_runner_.GetClock()->Now() - started, base::TimeDelta::FromSecondsD(2.9)); - on_success.Run(); + callback.Run(nullptr); task_runner_.Run(); EXPECT_GE(task_runner_.GetClock()->Now() - started, base::TimeDelta::FromSecondsD(2.9)); @@ -244,20 +244,20 @@ command_instance_->Complete({}, nullptr); // Device state #20 updated. - base::Closure on_success; + DoneCallback callback; const char expect1[] = R"({ 'progress': {'status':'ready'}, 'state':'inProgress' })"; - EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect1), _, _)) - .WillOnce(SaveArg<2>(&on_success)); + EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect1), _)) + .WillOnce(SaveArg<2>(&callback)); callbacks_.Notify(20); - on_success.Run(); + callback.Run(nullptr); // Device state #21 updated. const char expect2[] = "{'progress': {'status':'busy'}}"; - EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect2), _, _)) - .WillOnce(SaveArg<2>(&on_success)); + EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect2), _)) + .WillOnce(SaveArg<2>(&callback)); callbacks_.Notify(21); // Device state #22 updated. Nothing happens here since the previous command @@ -267,9 +267,9 @@ // Now the command update is complete, send out the patch that happened after // the state #22 was updated. const char expect3[] = "{'state': 'done'}"; - EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect3), _, _)) - .WillOnce(SaveArg<2>(&on_success)); - on_success.Run(); + EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect3), _)) + .WillOnce(SaveArg<2>(&callback)); + callback.Run(nullptr); } TEST_F(CloudCommandProxyTest, CombineSomeStates) { @@ -283,22 +283,22 @@ command_instance_->Complete({}, nullptr); // Device state 20-21 updated. - base::Closure on_success; + DoneCallback callback; const char expect1[] = R"({ 'progress': {'status':'busy'}, 'state':'inProgress' })"; - EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect1), _, _)) - .WillOnce(SaveArg<2>(&on_success)); + EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect1), _)) + .WillOnce(SaveArg<2>(&callback)); callbacks_.Notify(21); - on_success.Run(); + callback.Run(nullptr); // Device state #22 updated. const char expect2[] = "{'state': 'done'}"; - EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect2), _, _)) - .WillOnce(SaveArg<2>(&on_success)); + EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect2), _)) + .WillOnce(SaveArg<2>(&callback)); callbacks_.Notify(22); - on_success.Run(); + callback.Run(nullptr); } TEST_F(CloudCommandProxyTest, CombineAllStates) { @@ -316,7 +316,7 @@ 'progress': {'status':'busy'}, 'state':'done' })"; - EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _, _)); + EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _)); callbacks_.Notify(30); } @@ -336,7 +336,7 @@ 'results': {'sum':30}, 'state':'done' })"; - EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _, _)); + EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _)); callbacks_.Notify(30); } @@ -352,7 +352,7 @@ // As soon as we change the command, the update to the server should be sent. const char expected[] = "{'state':'done'}"; - EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _, _)); + EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _)); command_instance_->Complete({}, nullptr); task_runner_.RunOnce(); } @@ -370,7 +370,7 @@ // Only when the state #20 is published we should update the command const char expected[] = "{'state':'done'}"; - EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _, _)); + EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _)); callbacks_.Notify(20); }
diff --git a/libweave/src/commands/cloud_command_update_interface.h b/libweave/src/commands/cloud_command_update_interface.h index 5bd5c11..71c2f18 100644 --- a/libweave/src/commands/cloud_command_update_interface.h +++ b/libweave/src/commands/cloud_command_update_interface.h
@@ -18,8 +18,7 @@ public: virtual void UpdateCommand(const std::string& command_id, const base::DictionaryValue& command_patch, - const base::Closure& on_success, - const base::Closure& on_error) = 0; + const DoneCallback& callback) = 0; protected: virtual ~CloudCommandUpdateInterface() = default;
diff --git a/libweave/src/device_manager.cc b/libweave/src/device_manager.cc index 9105352..54570f6 100644 --- a/libweave/src/device_manager.cc +++ b/libweave/src/device_manager.cc
@@ -152,9 +152,8 @@ } void DeviceManager::Register(const std::string& ticket_id, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) { - device_info_->RegisterDevice(ticket_id, success_callback, error_callback); + const DoneCallback& callback) { + device_info_->RegisterDevice(ticket_id, callback); } void DeviceManager::AddPairingChangedCallbacks(
diff --git a/libweave/src/device_manager.h b/libweave/src/device_manager.h index 16e0601..7bded15 100644 --- a/libweave/src/device_manager.h +++ b/libweave/src/device_manager.h
@@ -59,8 +59,7 @@ ErrorPtr* error) override; std::unique_ptr<base::DictionaryValue> GetState() const override; void Register(const std::string& ticket_id, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) override; + const DoneCallback& callback) override; GcdState GetGcdState() const override; void AddGcdStateChangedCallback( const GcdStateChangedCallback& callback) override;
diff --git a/libweave/src/device_registration_info.cc b/libweave/src/device_registration_info.cc index 6481481..e35b768 100644 --- a/libweave/src/device_registration_info.cc +++ b/libweave/src/device_registration_info.cc
@@ -98,17 +98,18 @@ return AppendQueryParams(result, params); } -void IgnoreCloudError(ErrorPtr) {} - void IgnoreCloudErrorWithCallback(const base::Closure& cb, ErrorPtr) { cb.Run(); } -void IgnoreCloudResult(const base::DictionaryValue&) {} +void IgnoreCloudError(ErrorPtr) {} -void IgnoreCloudResultWithCallback(const base::Closure& cb, - const base::DictionaryValue&) { - cb.Run(); +void IgnoreCloudResult(const base::DictionaryValue&, ErrorPtr error) {} + +void IgnoreCloudResultWithCallback(const DoneCallback& cb, + const base::DictionaryValue&, + ErrorPtr error) { + cb.Run(std::move(error)); } class RequestSender final { @@ -118,31 +119,28 @@ HttpClient* transport) : method_{method}, url_{url}, transport_{transport} {} - void Send(const HttpClient::SuccessCallback& success_callback, - const ErrorCallback& error_callback) { + void Send(const HttpClient::SendRequestCallback& callback) { static int debug_id = 0; ++debug_id; VLOG(1) << "Sending request. id:" << debug_id << " method:" << EnumToString(method_) << " url:" << url_; VLOG(2) << "Request data: " << data_; - auto on_success = [](int debug_id, - const HttpClient::SuccessCallback& success_callback, - const HttpClient::Response& response) { - VLOG(1) << "Request succeeded. id:" << debug_id << " status:" << - response.GetStatusCode(); - VLOG(2) << "Response data: " << response.GetData(); - success_callback.Run(response); - }; - auto on_error = [](int debug_id, const ErrorCallback& error_callback, - ErrorPtr error) { - VLOG(1) << "Request failed, id=" << debug_id - << ", reason: " << error->GetCode() - << ", message: " << error->GetMessage(); - error_callback.Run(std::move(error)); + auto on_done = []( + int debug_id, const HttpClient::SendRequestCallback& callback, + std::unique_ptr<HttpClient::Response> response, ErrorPtr error) { + if (error) { + VLOG(1) << "Request failed, id=" << debug_id + << ", reason: " << error->GetCode() + << ", message: " << error->GetMessage(); + return callback.Run({}, std::move(error)); + } + VLOG(1) << "Request succeeded. id:" << debug_id + << " status:" << response->GetStatusCode(); + VLOG(2) << "Response data: " << response->GetData(); + callback.Run(std::move(response), nullptr); }; transport_->SendRequest(method_, url_, GetFullHeaders(), data_, - base::Bind(on_success, debug_id, success_callback), - base::Bind(on_error, debug_id, error_callback)); + base::Bind(on_done, debug_id, callback)); } void SetAccessToken(const std::string& access_token) { @@ -304,7 +302,8 @@ return; // Assume we're in test task_runner_->PostDelayedTask( FROM_HERE, - base::Bind(&DeviceRegistrationInfo::ConnectToCloud, AsWeakPtr()), delay); + base::Bind(&DeviceRegistrationInfo::ConnectToCloud, AsWeakPtr(), nullptr), + delay); } bool DeviceRegistrationInfo::HaveRegistrationCredentials() const { @@ -350,16 +349,12 @@ return resp; } -void DeviceRegistrationInfo::RefreshAccessToken( - const base::Closure& success_callback, - const ErrorCallback& error_callback) { +void DeviceRegistrationInfo::RefreshAccessToken(const DoneCallback& callback) { LOG(INFO) << "Refreshing access token."; ErrorPtr error; - if (!VerifyRegistrationCredentials(&error)) { - error_callback.Run(std::move(error)); - return; - } + if (!VerifyRegistrationCredentials(&error)) + return callback.Run(std::move(error)); if (oauth2_backoff_entry_->ShouldRejectRequest()) { VLOG(1) << "RefreshToken request delayed for " @@ -367,19 +362,11 @@ << " due to backoff policy"; task_runner_->PostDelayedTask( FROM_HERE, base::Bind(&DeviceRegistrationInfo::RefreshAccessToken, - AsWeakPtr(), success_callback, error_callback), + AsWeakPtr(), callback), oauth2_backoff_entry_->GetTimeUntilRelease()); return; } - // Make a shared pointer to |success_callback| and |error_callback| since we - // are going to share these callbacks with both success and error callbacks - // for PostFormData() and if the callbacks have any move-only types, - // one of the copies will be bad. - auto shared_success_callback = - std::make_shared<base::Closure>(success_callback); - auto shared_error_callback = std::make_shared<ErrorCallback>(error_callback); - RequestSender sender{HttpClient::Method::kPost, GetOAuthURL("token"), http_client_}; sender.SetFormData({ @@ -388,25 +375,25 @@ {"client_secret", GetSettings().client_secret}, {"grant_type", "refresh_token"}, }); - sender.Send(base::Bind(&DeviceRegistrationInfo::OnRefreshAccessTokenSuccess, - weak_factory_.GetWeakPtr(), shared_success_callback, - shared_error_callback), - base::Bind(&DeviceRegistrationInfo::OnRefreshAccessTokenError, - weak_factory_.GetWeakPtr(), shared_success_callback, - shared_error_callback)); + sender.Send(base::Bind(&DeviceRegistrationInfo::OnRefreshAccessTokenDone, + weak_factory_.GetWeakPtr(), callback)); VLOG(1) << "Refresh access token request dispatched"; } -void DeviceRegistrationInfo::OnRefreshAccessTokenSuccess( - const std::shared_ptr<base::Closure>& success_callback, - const std::shared_ptr<ErrorCallback>& error_callback, - const HttpClient::Response& response) { +void DeviceRegistrationInfo::OnRefreshAccessTokenDone( + const DoneCallback& callback, + std::unique_ptr<HttpClient::Response> response, + ErrorPtr error) { + if (error) { + VLOG(1) << "Refresh access token failed"; + oauth2_backoff_entry_->InformOfRequest(false); + return RefreshAccessToken(callback); + } VLOG(1) << "Refresh access token request completed"; oauth2_backoff_entry_->InformOfRequest(true); - ErrorPtr error; - auto json = ParseOAuthResponse(response, &error); + auto json = ParseOAuthResponse(*response, &error); if (!json) - return error_callback->Run(std::move(error)); + return callback.Run(std::move(error)); int expires_in = 0; if (!json->GetString("access_token", &access_token_) || @@ -415,7 +402,7 @@ LOG(ERROR) << "Access token unavailable."; Error::AddTo(&error, FROM_HERE, kErrorDomainOAuth2, "unexpected_server_response", "Access token unavailable"); - return error_callback->Run(std::move(error)); + return callback.Run(std::move(error)); } access_token_expiration_ = base::Time::Now() + base::TimeDelta::FromSeconds(expires_in); @@ -428,16 +415,7 @@ // Now that we have a new access token, retry the connection. StartNotificationChannel(); } - success_callback->Run(); -} - -void DeviceRegistrationInfo::OnRefreshAccessTokenError( - const std::shared_ptr<base::Closure>& success_callback, - const std::shared_ptr<ErrorCallback>& error_callback, - ErrorPtr error) { - VLOG(1) << "Refresh access token failed"; - oauth2_backoff_entry_->InformOfRequest(false); - RefreshAccessToken(*success_callback, *error_callback); + callback.Run(nullptr); } void DeviceRegistrationInfo::StartNotificationChannel() { @@ -522,45 +500,27 @@ } void DeviceRegistrationInfo::GetDeviceInfo( - const CloudRequestCallback& success_callback, - const ErrorCallback& error_callback) { + const CloudRequestDoneCallback& callback) { ErrorPtr error; if (!VerifyRegistrationCredentials(&error)) { - if (!error_callback.is_null()) - error_callback.Run(std::move(error)); - return; + return callback.Run({}, std::move(error)); } - DoCloudRequest(HttpClient::Method::kGet, GetDeviceURL(), nullptr, - success_callback, error_callback); + DoCloudRequest(HttpClient::Method::kGet, GetDeviceURL(), nullptr, callback); } -struct DeviceRegistrationInfo::RegisterCallbacks { - RegisterCallbacks(const SuccessCallback& success, const ErrorCallback& error) - : success_callback{success}, error_callback{error} {} - SuccessCallback success_callback; - ErrorCallback error_callback; -}; - -void DeviceRegistrationInfo::RegisterDeviceError( - const std::shared_ptr<RegisterCallbacks>& callbacks, - ErrorPtr error) { - task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(callbacks->error_callback, base::Passed(&error)), - {}); +void DeviceRegistrationInfo::RegisterDeviceError(const DoneCallback& callback, + ErrorPtr error) { + task_runner_->PostDelayedTask(FROM_HERE, + base::Bind(callback, base::Passed(&error)), {}); } -void DeviceRegistrationInfo::RegisterDevice( - const std::string& ticket_id, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) { - auto callbacks = - std::make_shared<RegisterCallbacks>(success_callback, error_callback); - +void DeviceRegistrationInfo::RegisterDevice(const std::string& ticket_id, + const DoneCallback& callback) { ErrorPtr error; std::unique_ptr<base::DictionaryValue> device_draft = BuildDeviceResource(&error); if (!device_draft) - return RegisterDeviceError(callbacks, std::move(error)); + return RegisterDeviceError(callback, std::move(error)); base::DictionaryValue req_json; req_json.SetString("id", ticket_id); @@ -573,23 +533,23 @@ RequestSender sender{HttpClient::Method::kPatch, url, http_client_}; sender.SetJsonData(req_json); sender.Send(base::Bind(&DeviceRegistrationInfo::RegisterDeviceOnTicketSent, - weak_factory_.GetWeakPtr(), ticket_id, callbacks), - base::Bind(&DeviceRegistrationInfo::RegisterDeviceError, - weak_factory_.GetWeakPtr(), callbacks)); + weak_factory_.GetWeakPtr(), ticket_id, callback)); } void DeviceRegistrationInfo::RegisterDeviceOnTicketSent( const std::string& ticket_id, - const std::shared_ptr<RegisterCallbacks>& callbacks, - const provider::HttpClient::Response& response) { - ErrorPtr error; - auto json_resp = ParseJsonResponse(response, &error); + const DoneCallback& callback, + std::unique_ptr<provider::HttpClient::Response> response, + ErrorPtr error) { + if (error) + return RegisterDeviceError(callback, std::move(error)); + auto json_resp = ParseJsonResponse(*response, &error); if (!json_resp) - return RegisterDeviceError(callbacks, std::move(error)); + return RegisterDeviceError(callback, std::move(error)); - if (!IsSuccessful(response)) { + if (!IsSuccessful(*response)) { ParseGCDError(json_resp.get(), &error); - return RegisterDeviceError(callbacks, std::move(error)); + return RegisterDeviceError(callback, std::move(error)); } std::string url = @@ -597,21 +557,21 @@ {{"key", GetSettings().api_key}}); RequestSender{HttpClient::Method::kPost, url, http_client_}.Send( base::Bind(&DeviceRegistrationInfo::RegisterDeviceOnTicketFinalized, - weak_factory_.GetWeakPtr(), callbacks), - base::Bind(&DeviceRegistrationInfo::RegisterDeviceError, - weak_factory_.GetWeakPtr(), callbacks)); + weak_factory_.GetWeakPtr(), callback)); } void DeviceRegistrationInfo::RegisterDeviceOnTicketFinalized( - const std::shared_ptr<RegisterCallbacks>& callbacks, - const provider::HttpClient::Response& response) { - ErrorPtr error; - auto json_resp = ParseJsonResponse(response, &error); + const DoneCallback& callback, + std::unique_ptr<provider::HttpClient::Response> response, + ErrorPtr error) { + if (error) + return RegisterDeviceError(callback, std::move(error)); + auto json_resp = ParseJsonResponse(*response, &error); if (!json_resp) - return RegisterDeviceError(callbacks, std::move(error)); - if (!IsSuccessful(response)) { + return RegisterDeviceError(callback, std::move(error)); + if (!IsSuccessful(*response)) { ParseGCDError(json_resp.get(), &error); - return RegisterDeviceError(callbacks, std::move(error)); + return RegisterDeviceError(callback, std::move(error)); } std::string auth_code; @@ -624,7 +584,7 @@ !device_draft_response->GetString("id", &cloud_id)) { Error::AddTo(&error, FROM_HERE, kErrorDomainGCD, "unexpected_response", "Device account missing in response"); - return RegisterDeviceError(callbacks, std::move(error)); + return RegisterDeviceError(callback, std::move(error)); } UpdateDeviceInfoTimestamp(*device_draft_response); @@ -641,18 +601,18 @@ {"grant_type", "authorization_code"}}); sender2.Send(base::Bind(&DeviceRegistrationInfo::RegisterDeviceOnAuthCodeSent, weak_factory_.GetWeakPtr(), cloud_id, robot_account, - callbacks), - base::Bind(&DeviceRegistrationInfo::RegisterDeviceError, - weak_factory_.GetWeakPtr(), callbacks)); + callback)); } void DeviceRegistrationInfo::RegisterDeviceOnAuthCodeSent( const std::string& cloud_id, const std::string& robot_account, - const std::shared_ptr<RegisterCallbacks>& callbacks, - const provider::HttpClient::Response& response) { - ErrorPtr error; - auto json_resp = ParseOAuthResponse(response, &error); + const DoneCallback& callback, + std::unique_ptr<provider::HttpClient::Response> response, + ErrorPtr error) { + if (error) + return RegisterDeviceError(callback, std::move(error)); + auto json_resp = ParseOAuthResponse(*response, &error); int expires_in = 0; std::string refresh_token; if (!json_resp || !json_resp->GetString("access_token", &access_token_) || @@ -661,7 +621,7 @@ access_token_.empty() || refresh_token.empty() || expires_in <= 0) { Error::AddTo(&error, FROM_HERE, kErrorDomainGCD, "unexpected_response", "Device access_token missing in response"); - return RegisterDeviceError(callbacks, std::move(error)); + return RegisterDeviceError(callback, std::move(error)); } access_token_expiration_ = @@ -673,7 +633,7 @@ change.set_refresh_token(refresh_token); change.Commit(); - task_runner_->PostDelayedTask(FROM_HERE, callbacks->success_callback, {}); + task_runner_->PostDelayedTask(FROM_HERE, base::Bind(callback, nullptr), {}); StartNotificationChannel(); @@ -686,10 +646,9 @@ HttpClient::Method method, const std::string& url, const base::DictionaryValue* body, - const CloudRequestCallback& success_callback, - const ErrorCallback& error_callback) { + const CloudRequestDoneCallback& callback) { // We make CloudRequestData shared here because we want to make sure - // there is only one instance of success_callback and error_calback since + // there is only one instance of callback and error_calback since // those may have move-only types and making a copy of the callback with // move-only types curried-in will invalidate the source callback. auto data = std::make_shared<CloudRequestData>(); @@ -697,8 +656,7 @@ data->url = url; if (body) base::JSONWriter::Write(*body, &data->body); - data->success_callback = success_callback; - data->error_callback = error_callback; + data->callback = callback; SendCloudRequest(data); } @@ -710,7 +668,7 @@ ErrorPtr error; if (!VerifyRegistrationCredentials(&error)) { - return data->error_callback.Run(std::move(error)); + return data->callback.Run({}, std::move(error)); } if (cloud_backoff_entry_->ShouldRejectRequest()) { @@ -726,22 +684,21 @@ RequestSender sender{data->method, data->url, http_client_}; sender.SetData(data->body, http::kJsonUtf8); sender.SetAccessToken(access_token_); - sender.Send(base::Bind(&DeviceRegistrationInfo::OnCloudRequestSuccess, - AsWeakPtr(), data), - base::Bind(&DeviceRegistrationInfo::OnCloudRequestError, + sender.Send(base::Bind(&DeviceRegistrationInfo::OnCloudRequestDone, AsWeakPtr(), data)); } -void DeviceRegistrationInfo::OnCloudRequestSuccess( +void DeviceRegistrationInfo::OnCloudRequestDone( const std::shared_ptr<const CloudRequestData>& data, - const HttpClient::Response& response) { - int status_code = response.GetStatusCode(); + std::unique_ptr<provider::HttpClient::Response> response, + ErrorPtr error) { + if (error) + return RetryCloudRequest(data); + int status_code = response->GetStatusCode(); if (status_code == http::kDenied) { cloud_backoff_entry_->InformOfRequest(true); RefreshAccessToken( base::Bind(&DeviceRegistrationInfo::OnAccessTokenRefreshed, AsWeakPtr(), - data), - base::Bind(&DeviceRegistrationInfo::OnAccessTokenError, AsWeakPtr(), data)); return; } @@ -755,14 +712,13 @@ return; } - ErrorPtr error; - auto json_resp = ParseJsonResponse(response, &error); + auto json_resp = ParseJsonResponse(*response, &error); if (!json_resp) { cloud_backoff_entry_->InformOfRequest(true); - return data->error_callback.Run(std::move(error)); + return data->callback.Run({}, std::move(error)); } - if (!IsSuccessful(response)) { + if (!IsSuccessful(*response)) { ParseGCDError(json_resp.get(), &error); if (status_code == http::kForbidden && error->HasError(kErrorDomainGCDServer, "rateLimitExceeded")) { @@ -770,18 +726,12 @@ return RetryCloudRequest(data); } cloud_backoff_entry_->InformOfRequest(true); - return data->error_callback.Run(std::move(error)); + return data->callback.Run({}, std::move(error)); } cloud_backoff_entry_->InformOfRequest(true); SetGcdState(GcdState::kConnected); - data->success_callback.Run(*json_resp); -} - -void DeviceRegistrationInfo::OnCloudRequestError( - const std::shared_ptr<const CloudRequestData>& data, - ErrorPtr error) { - RetryCloudRequest(data); + data->callback.Run(*json_resp, nullptr); } void DeviceRegistrationInfo::RetryCloudRequest( @@ -793,32 +743,34 @@ } void DeviceRegistrationInfo::OnAccessTokenRefreshed( - const std::shared_ptr<const CloudRequestData>& data) { + const std::shared_ptr<const CloudRequestData>& data, + ErrorPtr error) { + if (error) { + CheckAccessTokenError(error->Clone()); + return data->callback.Run({}, std::move(error)); + } SendCloudRequest(data); } -void DeviceRegistrationInfo::OnAccessTokenError( - const std::shared_ptr<const CloudRequestData>& data, - ErrorPtr error) { - CheckAccessTokenError(error->Clone()); - data->error_callback.Run(std::move(error)); -} - void DeviceRegistrationInfo::CheckAccessTokenError(ErrorPtr error) { - if (error->HasError(kErrorDomainOAuth2, "invalid_grant")) + if (error && error->HasError(kErrorDomainOAuth2, "invalid_grant")) MarkDeviceUnregistered(); } -void DeviceRegistrationInfo::ConnectToCloud() { +void DeviceRegistrationInfo::ConnectToCloud(ErrorPtr error) { + if (error) { + if (error->HasError(kErrorDomainOAuth2, "invalid_grant")) + MarkDeviceUnregistered(); + return; + } + connected_to_cloud_ = false; if (!VerifyRegistrationCredentials(nullptr)) return; if (access_token_.empty()) { RefreshAccessToken( - base::Bind(&DeviceRegistrationInfo::ConnectToCloud, AsWeakPtr()), - base::Bind(&DeviceRegistrationInfo::CheckAccessTokenError, - AsWeakPtr())); + base::Bind(&DeviceRegistrationInfo::ConnectToCloud, AsWeakPtr())); return; } @@ -828,16 +780,16 @@ // 3) abort any commands that we've previously marked as "in progress" // or as being in an error state; publish queued commands UpdateDeviceResource( - base::Bind(&DeviceRegistrationInfo::OnConnectedToCloud, AsWeakPtr()), - base::Bind(&IgnoreCloudError)); + base::Bind(&DeviceRegistrationInfo::OnConnectedToCloud, AsWeakPtr())); } -void DeviceRegistrationInfo::OnConnectedToCloud() { +void DeviceRegistrationInfo::OnConnectedToCloud(ErrorPtr error) { + if (error) + return; LOG(INFO) << "Device connected to cloud server"; connected_to_cloud_ = true; FetchCommands(base::Bind(&DeviceRegistrationInfo::ProcessInitialCommandList, - AsWeakPtr()), - base::Bind(&IgnoreCloudError)); + AsWeakPtr())); // In case there are any pending state updates since we sent off the initial // UpdateDeviceResource() request, update the server with any state changes. PublishStateUpdates(); @@ -853,8 +805,7 @@ change.Commit(); if (HaveRegistrationCredentials()) { - UpdateDeviceResource(base::Bind(&base::DoNothing), - base::Bind(&IgnoreCloudError)); + UpdateDeviceResource(base::Bind(&IgnoreCloudError)); } } @@ -891,12 +842,10 @@ void DeviceRegistrationInfo::UpdateCommand( const std::string& command_id, const base::DictionaryValue& command_patch, - const base::Closure& on_success, - const base::Closure& on_error) { + const DoneCallback& callback) { DoCloudRequest(HttpClient::Method::kPatch, GetServiceURL("commands/" + command_id), &command_patch, - base::Bind(&IgnoreCloudResultWithCallback, on_success), - base::Bind(&IgnoreCloudErrorWithCallback, on_error)); + base::Bind(&IgnoreCloudResultWithCallback, callback)); } void DeviceRegistrationInfo::NotifyCommandAborted(const std::string& command_id, @@ -908,14 +857,12 @@ command_patch.Set(commands::attributes::kCommand_Error, ErrorInfoToJson(*error).release()); } - UpdateCommand(command_id, command_patch, base::Bind(&base::DoNothing), - base::Bind(&base::DoNothing)); + UpdateCommand(command_id, command_patch, base::Bind(&IgnoreCloudError)); } void DeviceRegistrationInfo::UpdateDeviceResource( - const base::Closure& on_success, - const ErrorCallback& on_failure) { - queued_resource_update_callbacks_.emplace_back(on_success, on_failure); + const DoneCallback& callback) { + queued_resource_update_callbacks_.emplace_back(callback); if (!in_progress_resource_update_callbacks_.empty()) { VLOG(1) << "Another request is already pending."; return; @@ -935,10 +882,8 @@ // the request to guard against out-of-order requests overwriting settings // specified by later requests. VLOG(1) << "Getting the last device resource timestamp from server..."; - GetDeviceInfo( - base::Bind(&DeviceRegistrationInfo::OnDeviceInfoRetrieved, AsWeakPtr()), - base::Bind(&DeviceRegistrationInfo::OnUpdateDeviceResourceError, - AsWeakPtr())); + GetDeviceInfo(base::Bind(&DeviceRegistrationInfo::OnDeviceInfoRetrieved, + AsWeakPtr())); return; } @@ -959,16 +904,16 @@ std::string url = GetDeviceURL( {}, {{"lastUpdateTimeMs", last_device_resource_updated_timestamp_}}); - DoCloudRequest( - HttpClient::Method::kPut, url, device_resource.get(), - base::Bind(&DeviceRegistrationInfo::OnUpdateDeviceResourceSuccess, - AsWeakPtr()), - base::Bind(&DeviceRegistrationInfo::OnUpdateDeviceResourceError, - AsWeakPtr())); + DoCloudRequest(HttpClient::Method::kPut, url, device_resource.get(), + base::Bind(&DeviceRegistrationInfo::OnUpdateDeviceResourceDone, + AsWeakPtr())); } void DeviceRegistrationInfo::OnDeviceInfoRetrieved( - const base::DictionaryValue& device_info) { + const base::DictionaryValue& device_info, + ErrorPtr error) { + if (error) + return OnUpdateDeviceResourceError(std::move(error)); if (UpdateDeviceInfoTimestamp(device_info)) StartQueuedUpdateDeviceResource(); } @@ -987,15 +932,18 @@ return true; } -void DeviceRegistrationInfo::OnUpdateDeviceResourceSuccess( - const base::DictionaryValue& device_info) { +void DeviceRegistrationInfo::OnUpdateDeviceResourceDone( + const base::DictionaryValue& device_info, + ErrorPtr error) { + if (error) + return OnUpdateDeviceResourceError(std::move(error)); UpdateDeviceInfoTimestamp(device_info); // Make a copy of the callback list so that if the callback triggers another // call to UpdateDeviceResource(), we do not modify the list we are iterating // over. auto callback_list = std::move(in_progress_resource_update_callbacks_); - for (const auto& callback_pair : callback_list) - callback_pair.first.Run(); + for (const auto& callback : callback_list) + callback.Run(nullptr); StartQueuedUpdateDeviceResource(); } @@ -1004,10 +952,8 @@ // If the server rejected our previous request, retrieve the latest // timestamp from the server and retry. VLOG(1) << "Getting the last device resource timestamp from server..."; - GetDeviceInfo( - base::Bind(&DeviceRegistrationInfo::OnDeviceInfoRetrieved, AsWeakPtr()), - base::Bind(&DeviceRegistrationInfo::OnUpdateDeviceResourceError, - AsWeakPtr())); + GetDeviceInfo(base::Bind(&DeviceRegistrationInfo::OnDeviceInfoRetrieved, + AsWeakPtr())); return; } @@ -1015,28 +961,24 @@ // call to UpdateDeviceResource(), we do not modify the list we are iterating // over. auto callback_list = std::move(in_progress_resource_update_callbacks_); - for (const auto& callback_pair : callback_list) - callback_pair.second.Run(error->Clone()); + for (const auto& callback : callback_list) + callback.Run(error->Clone()); StartQueuedUpdateDeviceResource(); } -void DeviceRegistrationInfo::OnFetchCommandsSuccess( - const base::Callback<void(const base::ListValue&)>& callback, - const base::DictionaryValue& json) { +void DeviceRegistrationInfo::OnFetchCommandsDone( + const base::Callback<void(const base::ListValue&, ErrorPtr)>& callback, + const base::DictionaryValue& json, + ErrorPtr error) { OnFetchCommandsReturned(); + if (error) + return callback.Run({}, std::move(error)); const base::ListValue* commands{nullptr}; - if (!json.GetList("commands", &commands)) { + if (!json.GetList("commands", &commands)) VLOG(2) << "No commands in the response."; - } const base::ListValue empty; - callback.Run(commands ? *commands : empty); -} - -void DeviceRegistrationInfo::OnFetchCommandsError(const ErrorCallback& callback, - ErrorPtr error) { - OnFetchCommandsReturned(); - callback.Run(std::move(error)); + callback.Run(commands ? *commands : empty, nullptr); } void DeviceRegistrationInfo::OnFetchCommandsReturned() { @@ -1047,17 +989,15 @@ } void DeviceRegistrationInfo::FetchCommands( - const base::Callback<void(const base::ListValue&)>& on_success, - const ErrorCallback& on_failure) { + const base::Callback<void(const base::ListValue&, ErrorPtr error)>& + callback) { fetch_commands_request_sent_ = true; fetch_commands_request_queued_ = false; DoCloudRequest( HttpClient::Method::kGet, GetServiceURL("commands/queue", {{"deviceId", GetSettings().cloud_id}}), - nullptr, base::Bind(&DeviceRegistrationInfo::OnFetchCommandsSuccess, - AsWeakPtr(), on_success), - base::Bind(&DeviceRegistrationInfo::OnFetchCommandsError, AsWeakPtr(), - on_failure)); + nullptr, base::Bind(&DeviceRegistrationInfo::OnFetchCommandsDone, + AsWeakPtr(), callback)); } void DeviceRegistrationInfo::FetchAndPublishCommands() { @@ -1067,12 +1007,14 @@ } FetchCommands(base::Bind(&DeviceRegistrationInfo::PublishCommands, - weak_factory_.GetWeakPtr()), - base::Bind(&IgnoreCloudError)); + weak_factory_.GetWeakPtr())); } void DeviceRegistrationInfo::ProcessInitialCommandList( - const base::ListValue& commands) { + const base::ListValue& commands, + ErrorPtr error) { + if (error) + return; for (const base::Value* command : commands) { const base::DictionaryValue* command_dict{nullptr}; if (!command->GetAsDictionary(&command_dict)) { @@ -1098,8 +1040,7 @@ // TODO(wiley) We could consider handling this error case more gracefully. DoCloudRequest(HttpClient::Method::kPut, GetServiceURL("commands/" + command_id), cmd_copy.get(), - base::Bind(&IgnoreCloudResult), - base::Bind(&IgnoreCloudError)); + base::Bind(&IgnoreCloudResult)); } else { // Normal command, publish it to local clients. PublishCommand(*command_dict); @@ -1107,7 +1048,10 @@ } } -void DeviceRegistrationInfo::PublishCommands(const base::ListValue& commands) { +void DeviceRegistrationInfo::PublishCommands(const base::ListValue& commands, + ErrorPtr error) { + if (error) + return; for (const base::Value* command : commands) { const base::DictionaryValue* command_dict{nullptr}; if (!command->GetAsDictionary(&command_dict)) { @@ -1188,28 +1132,26 @@ body.Set("patches", patches.release()); device_state_update_pending_ = true; - DoCloudRequest( - HttpClient::Method::kPost, GetDeviceURL("patchState"), &body, - base::Bind(&DeviceRegistrationInfo::OnPublishStateSuccess, AsWeakPtr(), - update_id), - base::Bind(&DeviceRegistrationInfo::OnPublishStateError, AsWeakPtr())); + DoCloudRequest(HttpClient::Method::kPost, GetDeviceURL("patchState"), &body, + base::Bind(&DeviceRegistrationInfo::OnPublishStateDone, + AsWeakPtr(), update_id)); } -void DeviceRegistrationInfo::OnPublishStateSuccess( +void DeviceRegistrationInfo::OnPublishStateDone( StateChangeQueueInterface::UpdateID update_id, - const base::DictionaryValue& reply) { + const base::DictionaryValue& reply, + ErrorPtr error) { device_state_update_pending_ = false; + if (error) { + LOG(ERROR) << "Permanent failure while trying to update device state"; + return; + } state_manager_->NotifyStateUpdatedOnServer(update_id); // See if there were more pending state updates since the previous request // had been sent out. PublishStateUpdates(); } -void DeviceRegistrationInfo::OnPublishStateError(ErrorPtr error) { - LOG(ERROR) << "Permanent failure while trying to update device state"; - device_state_update_pending_ = false; -} - void DeviceRegistrationInfo::SetGcdState(GcdState new_state) { VLOG_IF(1, new_state != gcd_state_) << "Changing registration status to " << EnumToString(new_state); @@ -1223,8 +1165,7 @@ if (!HaveRegistrationCredentials() || !connected_to_cloud_) return; - UpdateDeviceResource(base::Bind(&base::DoNothing), - base::Bind(&IgnoreCloudError)); + UpdateDeviceResource(base::Bind(&IgnoreCloudError)); } void DeviceRegistrationInfo::OnStateChanged() { @@ -1257,8 +1198,9 @@ // the moment of the last poll and the time we successfully told the server // to send new commands over the new notification channel. UpdateDeviceResource( - base::Bind(&DeviceRegistrationInfo::FetchAndPublishCommands, AsWeakPtr()), - base::Bind(&IgnoreCloudError)); + base::Bind(&IgnoreCloudErrorWithCallback, + base::Bind(&DeviceRegistrationInfo::FetchAndPublishCommands, + AsWeakPtr()))); } void DeviceRegistrationInfo::OnDisconnected() { @@ -1269,15 +1211,13 @@ pull_channel_->UpdatePullInterval( base::TimeDelta::FromSeconds(kPollingPeriodSeconds)); current_notification_channel_ = pull_channel_.get(); - UpdateDeviceResource(base::Bind(&base::DoNothing), - base::Bind(&IgnoreCloudError)); + UpdateDeviceResource(base::Bind(&IgnoreCloudError)); } void DeviceRegistrationInfo::OnPermanentFailure() { LOG(ERROR) << "Failed to establish notification channel."; notification_channel_starting_ = false; RefreshAccessToken( - base::Bind(&base::DoNothing), base::Bind(&DeviceRegistrationInfo::CheckAccessTokenError, AsWeakPtr())); }
diff --git a/libweave/src/device_registration_info.h b/libweave/src/device_registration_info.h index a7f79b6..d389290 100644 --- a/libweave/src/device_registration_info.h +++ b/libweave/src/device_registration_info.h
@@ -50,8 +50,9 @@ class DeviceRegistrationInfo : public NotificationDelegate, public CloudCommandUpdateInterface { public: - using CloudRequestCallback = - base::Callback<void(const base::DictionaryValue& response)>; + using CloudRequestDoneCallback = + base::Callback<void(const base::DictionaryValue& response, + ErrorPtr error)>; DeviceRegistrationInfo(const std::shared_ptr<CommandManager>& command_manager, const std::shared_ptr<StateManager>& state_manager, @@ -65,8 +66,7 @@ void AddGcdStateChangedCallback( const Device::GcdStateChangedCallback& callback); void RegisterDevice(const std::string& ticket_id, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback); + const DoneCallback& callback); void UpdateDeviceInfo(const std::string& name, const std::string& description, @@ -81,8 +81,7 @@ const std::string& service_url, ErrorPtr* error); - void GetDeviceInfo(const CloudRequestCallback& success_callback, - const ErrorCallback& error_callback); + void GetDeviceInfo(const CloudRequestDoneCallback& callback); // Returns the GCD service request URL. If |subpath| is specified, it is // appended to the base URL which is normally @@ -120,8 +119,7 @@ // Updates a command (override from CloudCommandUpdateInterface). void UpdateCommand(const std::string& command_id, const base::DictionaryValue& command_patch, - const base::Closure& on_success, - const base::Closure& on_error) override; + const DoneCallback& callback) override; // TODO(vitalybuka): remove getters and pass config to dependent code. const Config::Settings& GetSettings() const { return config_->GetSettings(); } @@ -143,22 +141,17 @@ // Initiates the connection to the cloud server. // Device will do required start up chores and then start to listen // to new commands. - void ConnectToCloud(); + void ConnectToCloud(ErrorPtr error); // Notification called when ConnectToCloud() succeeds. - void OnConnectedToCloud(); + void OnConnectedToCloud(ErrorPtr error); // Forcibly refreshes the access token. - void RefreshAccessToken(const base::Closure& success_callback, - const ErrorCallback& error_callback); + void RefreshAccessToken(const DoneCallback& callback); // Callbacks for RefreshAccessToken(). - void OnRefreshAccessTokenSuccess( - const std::shared_ptr<base::Closure>& success_callback, - const std::shared_ptr<ErrorCallback>& error_callback, - const provider::HttpClient::Response& response); - void OnRefreshAccessTokenError( - const std::shared_ptr<base::Closure>& success_callback, - const std::shared_ptr<ErrorCallback>& error_callback, + void OnRefreshAccessTokenDone( + const DoneCallback& callback, + std::unique_ptr<provider::HttpClient::Response> response, ErrorPtr error); // Parse the OAuth response, and sets registration status to @@ -179,40 +172,36 @@ void DoCloudRequest(provider::HttpClient::Method method, const std::string& url, const base::DictionaryValue* body, - const CloudRequestCallback& success_callback, - const ErrorCallback& error_callback); + const CloudRequestDoneCallback& callback); // Helper for DoCloudRequest(). struct CloudRequestData { provider::HttpClient::Method method; std::string url; std::string body; - CloudRequestCallback success_callback; - ErrorCallback error_callback; + CloudRequestDoneCallback callback; }; void SendCloudRequest(const std::shared_ptr<const CloudRequestData>& data); - void OnCloudRequestSuccess( + void OnCloudRequestDone( const std::shared_ptr<const CloudRequestData>& data, - const provider::HttpClient::Response& response); - void OnCloudRequestError(const std::shared_ptr<const CloudRequestData>& data, - ErrorPtr error); + std::unique_ptr<provider::HttpClient::Response> response, + ErrorPtr error); void RetryCloudRequest(const std::shared_ptr<const CloudRequestData>& data); void OnAccessTokenRefreshed( - const std::shared_ptr<const CloudRequestData>& data); - void OnAccessTokenError(const std::shared_ptr<const CloudRequestData>& data, - ErrorPtr error); + const std::shared_ptr<const CloudRequestData>& data, + ErrorPtr error); void CheckAccessTokenError(ErrorPtr error); - void UpdateDeviceResource(const base::Closure& on_success, - const ErrorCallback& on_failure); + void UpdateDeviceResource(const DoneCallback& callback); void StartQueuedUpdateDeviceResource(); - // Success/failure callbacks for UpdateDeviceResource(). - void OnUpdateDeviceResourceSuccess(const base::DictionaryValue& device_info); + void OnUpdateDeviceResourceDone(const base::DictionaryValue& device_info, + ErrorPtr error); void OnUpdateDeviceResourceError(ErrorPtr error); // Callback from GetDeviceInfo() to retrieve the device resource timestamp // and retry UpdateDeviceResource() call. - void OnDeviceInfoRetrieved(const base::DictionaryValue& device_info); + void OnDeviceInfoRetrieved(const base::DictionaryValue& device_info, + ErrorPtr error); // Extracts the timestamp from the device resource and sets it to // |last_device_resource_updated_timestamp_|. @@ -221,13 +210,11 @@ bool UpdateDeviceInfoTimestamp(const base::DictionaryValue& device_info); void FetchCommands( - const base::Callback<void(const base::ListValue&)>& on_success, - const ErrorCallback& on_failure); - // Success/failure callbacks for FetchCommands(). - void OnFetchCommandsSuccess( - const base::Callback<void(const base::ListValue&)>& callback, - const base::DictionaryValue& json); - void OnFetchCommandsError(const ErrorCallback& callback, ErrorPtr error); + const base::Callback<void(const base::ListValue&, ErrorPtr)>& callback); + void OnFetchCommandsDone( + const base::Callback<void(const base::ListValue&, ErrorPtr)>& callback, + const base::DictionaryValue& json, + ErrorPtr); // Called when FetchCommands completes (with either success or error). // This method reschedules any pending/queued fetch requests. void OnFetchCommandsReturned(); @@ -235,9 +222,10 @@ // Processes the command list that is fetched from the server on connection. // Aborts commands which are in transitional states and publishes queued // commands which are queued. - void ProcessInitialCommandList(const base::ListValue& commands); + void ProcessInitialCommandList(const base::ListValue& commands, + ErrorPtr error); - void PublishCommands(const base::ListValue& commands); + void PublishCommands(const base::ListValue& commands, ErrorPtr error); void PublishCommand(const base::DictionaryValue& command); // Helper function to pull the pending command list from the server using @@ -245,8 +233,9 @@ void FetchAndPublishCommands(); void PublishStateUpdates(); - void OnPublishStateSuccess(StateChangeQueueInterface::UpdateID update_id, - const base::DictionaryValue& reply); + void OnPublishStateDone(StateChangeQueueInterface::UpdateID update_id, + const base::DictionaryValue& reply, + ErrorPtr error); void OnPublishStateError(ErrorPtr error); // If unrecoverable error occurred (e.g. error parsing command instance), @@ -275,21 +264,22 @@ // Wipes out the device registration information and stops server connections. void MarkDeviceUnregistered(); - struct RegisterCallbacks; - void RegisterDeviceError(const std::shared_ptr<RegisterCallbacks>& callbacks, - ErrorPtr error); + void RegisterDeviceError(const DoneCallback& callback, ErrorPtr error); void RegisterDeviceOnTicketSent( const std::string& ticket_id, - const std::shared_ptr<RegisterCallbacks>& callbacks, - const provider::HttpClient::Response& response); + const DoneCallback& callback, + std::unique_ptr<provider::HttpClient::Response> response, + ErrorPtr error); void RegisterDeviceOnTicketFinalized( - const std::shared_ptr<RegisterCallbacks>& callbacks, - const provider::HttpClient::Response& response); + const DoneCallback& callback, + std::unique_ptr<provider::HttpClient::Response> response, + ErrorPtr error); void RegisterDeviceOnAuthCodeSent( const std::string& cloud_id, const std::string& robot_account, - const std::shared_ptr<RegisterCallbacks>& callbacks, - const provider::HttpClient::Response& response); + const DoneCallback& callback, + std::unique_ptr<provider::HttpClient::Response> response, + ErrorPtr error); // Transient data std::string access_token_; @@ -327,13 +317,12 @@ // another one was in flight. bool fetch_commands_request_queued_{false}; - using ResourceUpdateCallbackList = - std::vector<std::pair<base::Closure, ErrorCallback>>; - // Success/error callbacks for device resource update request currently in - // flight to the cloud server. + using ResourceUpdateCallbackList = std::vector<DoneCallback>; + // Callbacks for device resource update request currently in flight to the + // cloud server. ResourceUpdateCallbackList in_progress_resource_update_callbacks_; - // Success/error callbacks for device resource update requests queued while - // another request is in flight to the cloud server. + // Callbacks for device resource update requests queued while another request + // is in flight to the cloud server. ResourceUpdateCallbackList queued_resource_update_callbacks_; std::unique_ptr<NotificationChannel> primary_notification_channel_;
diff --git a/libweave/src/device_registration_info_unittest.cc b/libweave/src/device_registration_info_unittest.cc index 6d00a51..b086f8b 100644 --- a/libweave/src/device_registration_info_unittest.cc +++ b/libweave/src/device_registration_info_unittest.cc
@@ -166,18 +166,19 @@ } void PublishCommands(const base::ListValue& commands) { - return dev_reg_->PublishCommands(commands); + dev_reg_->PublishCommands(commands, nullptr); } bool RefreshAccessToken(ErrorPtr* error) const { bool succeeded = false; - auto on_success = [&succeeded]() { succeeded = true; }; - auto on_failure = [&error](ErrorPtr in_error) { - if (error) + auto callback = [&succeeded, &error](ErrorPtr in_error) { + if (error) { *error = std::move(in_error); + return; + } + succeeded = true; }; - dev_reg_->RefreshAccessToken(base::Bind(on_success), - base::Bind(on_failure)); + dev_reg_->RefreshAccessToken(base::Bind(callback)); return succeeded; } @@ -236,10 +237,10 @@ EXPECT_CALL( http_client_, SendRequest(HttpClient::Method::kPost, dev_reg_->GetOAuthURL("token"), - HttpClient::Headers{GetFormHeader()}, _, _, _)) + HttpClient::Headers{GetFormHeader()}, _, _)) .WillOnce(WithArgs<3, 4>(Invoke([]( const std::string& data, - const HttpClient::SuccessCallback& callback) { + const HttpClient::SendRequestCallback& callback) { EXPECT_EQ("refresh_token", GetFormField(data, "grant_type")); EXPECT_EQ(test_data::kRefreshToken, GetFormField(data, "refresh_token")); @@ -251,7 +252,7 @@ json.SetString("access_token", test_data::kAccessToken); json.SetInteger("expires_in", 3600); - callback.Run(*ReplyWithJson(200, json)); + callback.Run(ReplyWithJson(200, json), nullptr); }))); EXPECT_TRUE(RefreshAccessToken(nullptr)); @@ -265,10 +266,10 @@ EXPECT_CALL( http_client_, SendRequest(HttpClient::Method::kPost, dev_reg_->GetOAuthURL("token"), - HttpClient::Headers{GetFormHeader()}, _, _, _)) + HttpClient::Headers{GetFormHeader()}, _, _)) .WillOnce(WithArgs<3, 4>(Invoke([]( const std::string& data, - const HttpClient::SuccessCallback& callback) { + const HttpClient::SendRequestCallback& callback) { EXPECT_EQ("refresh_token", GetFormField(data, "grant_type")); EXPECT_EQ(test_data::kRefreshToken, GetFormField(data, "refresh_token")); @@ -278,7 +279,7 @@ base::DictionaryValue json; json.SetString("error", "unable_to_authenticate"); - callback.Run(*ReplyWithJson(400, json)); + callback.Run(ReplyWithJson(400, json), nullptr); }))); ErrorPtr error; @@ -294,10 +295,10 @@ EXPECT_CALL( http_client_, SendRequest(HttpClient::Method::kPost, dev_reg_->GetOAuthURL("token"), - HttpClient::Headers{GetFormHeader()}, _, _, _)) + HttpClient::Headers{GetFormHeader()}, _, _)) .WillOnce(WithArgs<3, 4>(Invoke([]( const std::string& data, - const HttpClient::SuccessCallback& callback) { + const HttpClient::SendRequestCallback& callback) { EXPECT_EQ("refresh_token", GetFormField(data, "grant_type")); EXPECT_EQ(test_data::kRefreshToken, GetFormField(data, "refresh_token")); @@ -307,7 +308,7 @@ base::DictionaryValue json; json.SetString("error", "invalid_grant"); - callback.Run(*ReplyWithJson(400, json)); + callback.Run(ReplyWithJson(400, json), nullptr); }))); ErrorPtr error; @@ -320,30 +321,31 @@ ReloadSettings(); SetAccessToken(); - EXPECT_CALL(http_client_, - SendRequest(HttpClient::Method::kGet, dev_reg_->GetDeviceURL(), - HttpClient::Headers{GetAuthHeader(), GetJsonHeader()}, - _, _, _)) + EXPECT_CALL( + http_client_, + SendRequest(HttpClient::Method::kGet, dev_reg_->GetDeviceURL(), + HttpClient::Headers{GetAuthHeader(), GetJsonHeader()}, _, _)) .WillOnce(WithArgs<3, 4>( Invoke([](const std::string& data, - const HttpClient::SuccessCallback& callback) { + const HttpClient::SendRequestCallback& callback) { base::DictionaryValue json; json.SetString("channel.supportedType", "xmpp"); json.SetString("deviceKind", "vendor"); json.SetString("id", test_data::kDeviceId); json.SetString("kind", "clouddevices#device"); - callback.Run(*ReplyWithJson(200, json)); + callback.Run(ReplyWithJson(200, json), nullptr); }))); bool succeeded = false; - auto on_success = [&succeeded, this](const base::DictionaryValue& info) { + auto callback = [&succeeded, this](const base::DictionaryValue& info, + ErrorPtr error) { + EXPECT_FALSE(error); std::string id; EXPECT_TRUE(info.GetString("id", &id)); EXPECT_EQ(test_data::kDeviceId, id); succeeded = true; }; - auto on_failure = [](ErrorPtr error) { FAIL() << "Should not be called"; }; - dev_reg_->GetDeviceInfo(base::Bind(on_success), base::Bind(on_failure)); + dev_reg_->GetDeviceInfo(base::Bind(callback)); EXPECT_TRUE(succeeded); } @@ -386,10 +388,10 @@ EXPECT_CALL(http_client_, SendRequest(HttpClient::Method::kPatch, ticket_url + "?key=" + test_data::kApiKey, - HttpClient::Headers{GetJsonHeader()}, _, _, _)) + HttpClient::Headers{GetJsonHeader()}, _, _)) .WillOnce(WithArgs<3, 4>(Invoke([]( const std::string& data, - const HttpClient::SuccessCallback& callback) { + const HttpClient::SendRequestCallback& callback) { auto json = test::CreateDictionaryValue(data); EXPECT_NE(nullptr, json.get()); std::string value; @@ -449,15 +451,15 @@ device_draft->SetString("kind", "clouddevices#device"); json_resp.Set("deviceDraft", device_draft); - callback.Run(*ReplyWithJson(200, json_resp)); + callback.Run(ReplyWithJson(200, json_resp), nullptr); }))); EXPECT_CALL(http_client_, SendRequest(HttpClient::Method::kPost, ticket_url + "/finalize?key=" + test_data::kApiKey, - HttpClient::Headers{}, _, _, _)) - .WillOnce( - WithArgs<4>(Invoke([](const HttpClient::SuccessCallback& callback) { + HttpClient::Headers{}, _, _)) + .WillOnce(WithArgs<4>( + Invoke([](const HttpClient::SendRequestCallback& callback) { base::DictionaryValue json; json.SetString("id", test_data::kClaimTicketId); json.SetString("kind", "clouddevices#registrationTicket"); @@ -469,16 +471,16 @@ json.SetString("robotAccountEmail", test_data::kRobotAccountEmail); json.SetString("robotAccountAuthorizationCode", test_data::kRobotAccountAuthCode); - callback.Run(*ReplyWithJson(200, json)); + callback.Run(ReplyWithJson(200, json), nullptr); }))); EXPECT_CALL( http_client_, SendRequest(HttpClient::Method::kPost, dev_reg_->GetOAuthURL("token"), - HttpClient::Headers{GetFormHeader()}, _, _, _)) + HttpClient::Headers{GetFormHeader()}, _, _)) .WillOnce(WithArgs<3, 4>(Invoke([]( const std::string& data, - const HttpClient::SuccessCallback& callback) { + const HttpClient::SendRequestCallback& callback) { EXPECT_EQ("authorization_code", GetFormField(data, "grant_type")); EXPECT_EQ(test_data::kRobotAccountAuthCode, GetFormField(data, "code")); EXPECT_EQ(test_data::kClientId, GetFormField(data, "client_id")); @@ -495,12 +497,13 @@ json.SetString("refresh_token", test_data::kRefreshToken); json.SetInteger("expires_in", 3600); - callback.Run(*ReplyWithJson(200, json)); + callback.Run(ReplyWithJson(200, json), nullptr); }))); bool done = false; dev_reg_->RegisterDevice( - test_data::kClaimTicketId, base::Bind([this, &done]() { + test_data::kClaimTicketId, base::Bind([this, &done](ErrorPtr error) { + EXPECT_FALSE(error); done = true; task_runner_.Break(); EXPECT_EQ(GcdState::kConnecting, GetGcdState()); @@ -511,8 +514,7 @@ dev_reg_->GetSettings().refresh_token); EXPECT_EQ(test_data::kRobotAccountEmail, dev_reg_->GetSettings().robot_account); - }), - base::Bind([](ErrorPtr error) { ADD_FAILURE(); })); + })); task_runner_.Run(); EXPECT_TRUE(done); } @@ -573,51 +575,51 @@ }; TEST_F(DeviceRegistrationInfoUpdateCommandTest, SetProgress) { - EXPECT_CALL(http_client_, - SendRequest(HttpClient::Method::kPatch, command_url_, - HttpClient::Headers{GetAuthHeader(), GetJsonHeader()}, - _, _, _)) + EXPECT_CALL( + http_client_, + SendRequest(HttpClient::Method::kPatch, command_url_, + HttpClient::Headers{GetAuthHeader(), GetJsonHeader()}, _, _)) .WillOnce(WithArgs<3, 4>(Invoke([]( const std::string& data, - const HttpClient::SuccessCallback& callback) { + const HttpClient::SendRequestCallback& callback) { EXPECT_JSON_EQ((R"({"state":"inProgress","progress":{"progress":18}})"), *CreateDictionaryValue(data)); base::DictionaryValue json; - callback.Run(*ReplyWithJson(200, json)); + callback.Run(ReplyWithJson(200, json), nullptr); }))); EXPECT_TRUE(command_->SetProgress(*CreateDictionaryValue("{'progress':18}"), nullptr)); } TEST_F(DeviceRegistrationInfoUpdateCommandTest, Complete) { - EXPECT_CALL(http_client_, - SendRequest(HttpClient::Method::kPatch, command_url_, - HttpClient::Headers{GetAuthHeader(), GetJsonHeader()}, - _, _, _)) + EXPECT_CALL( + http_client_, + SendRequest(HttpClient::Method::kPatch, command_url_, + HttpClient::Headers{GetAuthHeader(), GetJsonHeader()}, _, _)) .WillOnce(WithArgs<3, 4>(Invoke([]( const std::string& data, - const HttpClient::SuccessCallback& callback) { + const HttpClient::SendRequestCallback& callback) { EXPECT_JSON_EQ(R"({"state":"done", "results":{"status":"Ok"}})", *CreateDictionaryValue(data)); base::DictionaryValue json; - callback.Run(*ReplyWithJson(200, json)); + callback.Run(ReplyWithJson(200, json), nullptr); }))); EXPECT_TRUE( command_->Complete(*CreateDictionaryValue("{'status': 'Ok'}"), nullptr)); } TEST_F(DeviceRegistrationInfoUpdateCommandTest, Cancel) { - EXPECT_CALL(http_client_, - SendRequest(HttpClient::Method::kPatch, command_url_, - HttpClient::Headers{GetAuthHeader(), GetJsonHeader()}, - _, _, _)) + EXPECT_CALL( + http_client_, + SendRequest(HttpClient::Method::kPatch, command_url_, + HttpClient::Headers{GetAuthHeader(), GetJsonHeader()}, _, _)) .WillOnce(WithArgs<3, 4>(Invoke([]( const std::string& data, - const HttpClient::SuccessCallback& callback) { + const HttpClient::SendRequestCallback& callback) { EXPECT_JSON_EQ(R"({"state":"cancelled"})", *CreateDictionaryValue(data)); base::DictionaryValue json; - callback.Run(*ReplyWithJson(200, json)); + callback.Run(ReplyWithJson(200, json), nullptr); }))); EXPECT_TRUE(command_->Cancel(nullptr)); }
diff --git a/libweave/src/notification/xmpp_channel.cc b/libweave/src/notification/xmpp_channel.cc index 0d22286..c88ba45 100644 --- a/libweave/src/notification/xmpp_channel.cc +++ b/libweave/src/notification/xmpp_channel.cc
@@ -106,10 +106,12 @@ } } -void XmppChannel::OnMessageRead(size_t size) { +void XmppChannel::OnMessageRead(size_t size, ErrorPtr error) { + read_pending_ = false; + if (error) + return Restart(); std::string msg(read_socket_data_.data(), size); VLOG(2) << "Received XMPP packet: '" << msg << "'"; - read_pending_ = false; if (!size) return Restart(); @@ -286,14 +288,24 @@ LOG(INFO) << "Starting XMPP connection to " << kDefaultXmppHost << ":" << kDefaultXmppPort; - network_->OpenSslSocket( - kDefaultXmppHost, kDefaultXmppPort, - base::Bind(&XmppChannel::OnSslSocketReady, - task_ptr_factory_.GetWeakPtr()), - base::Bind(&XmppChannel::OnSslError, task_ptr_factory_.GetWeakPtr())); + network_->OpenSslSocket(kDefaultXmppHost, kDefaultXmppPort, + base::Bind(&XmppChannel::OnSslSocketReady, + task_ptr_factory_.GetWeakPtr())); } -void XmppChannel::OnSslSocketReady(std::unique_ptr<Stream> stream) { +void XmppChannel::OnSslSocketReady(std::unique_ptr<Stream> stream, + ErrorPtr error) { + if (error) { + LOG(ERROR) << "TLS handshake failed. Restarting XMPP connection"; + backoff_entry_.InformOfRequest(false); + + LOG(INFO) << "Delaying connection to XMPP server for " + << backoff_entry_.GetTimeUntilRelease(); + return task_runner_->PostDelayedTask( + FROM_HERE, base::Bind(&XmppChannel::CreateSslSocket, + task_ptr_factory_.GetWeakPtr()), + backoff_entry_.GetTimeUntilRelease()); + } CHECK(XmppState::kConnecting == state_); backoff_entry_.InformOfRequest(true); stream_ = std::move(stream); @@ -302,18 +314,6 @@ ScheduleRegularPing(); } -void XmppChannel::OnSslError(ErrorPtr error) { - LOG(ERROR) << "TLS handshake failed. Restarting XMPP connection"; - backoff_entry_.InformOfRequest(false); - - LOG(INFO) << "Delaying connection to XMPP server for " - << backoff_entry_.GetTimeUntilRelease(); - task_runner_->PostDelayedTask( - FROM_HERE, - base::Bind(&XmppChannel::CreateSslSocket, task_ptr_factory_.GetWeakPtr()), - backoff_entry_.GetTimeUntilRelease()); -} - void XmppChannel::SendMessage(const std::string& message) { CHECK(stream_) << "No XMPP socket stream available"; if (write_pending_) { @@ -327,13 +327,13 @@ write_pending_ = true; stream_->Write( write_socket_data_.data(), write_socket_data_.size(), - base::Bind(&XmppChannel::OnMessageSent, task_ptr_factory_.GetWeakPtr()), - base::Bind(&XmppChannel::OnWriteError, task_ptr_factory_.GetWeakPtr())); + base::Bind(&XmppChannel::OnMessageSent, task_ptr_factory_.GetWeakPtr())); } -void XmppChannel::OnMessageSent() { - ErrorPtr error; +void XmppChannel::OnMessageSent(ErrorPtr error) { write_pending_ = false; + if (error) + return Restart(); if (queued_write_data_.empty()) { WaitForMessage(); } else { @@ -348,18 +348,7 @@ read_pending_ = true; stream_->Read( read_socket_data_.data(), read_socket_data_.size(), - base::Bind(&XmppChannel::OnMessageRead, task_ptr_factory_.GetWeakPtr()), - base::Bind(&XmppChannel::OnReadError, task_ptr_factory_.GetWeakPtr())); -} - -void XmppChannel::OnReadError(ErrorPtr error) { - read_pending_ = false; - Restart(); -} - -void XmppChannel::OnWriteError(ErrorPtr error) { - write_pending_ = false; - Restart(); + base::Bind(&XmppChannel::OnMessageRead, task_ptr_factory_.GetWeakPtr())); } std::string XmppChannel::GetName() const {
diff --git a/libweave/src/notification/xmpp_channel.h b/libweave/src/notification/xmpp_channel.h index 814b2a5..a40eca9 100644 --- a/libweave/src/notification/xmpp_channel.h +++ b/libweave/src/notification/xmpp_channel.h
@@ -97,15 +97,12 @@ void RestartXmppStream(); void CreateSslSocket(); - void OnSslSocketReady(std::unique_ptr<Stream> stream); - void OnSslError(ErrorPtr error); + void OnSslSocketReady(std::unique_ptr<Stream> stream, ErrorPtr error); void WaitForMessage(); - void OnMessageRead(size_t size); - void OnMessageSent(); - void OnReadError(ErrorPtr error); - void OnWriteError(ErrorPtr error); + void OnMessageRead(size_t size, ErrorPtr error); + void OnMessageSent(ErrorPtr error); void Restart(); void CloseStream();
diff --git a/libweave/src/notification/xmpp_channel_unittest.cc b/libweave/src/notification/xmpp_channel_unittest.cc index c6e0be1..6c334dd 100644 --- a/libweave/src/notification/xmpp_channel_unittest.cc +++ b/libweave/src/notification/xmpp_channel_unittest.cc
@@ -88,8 +88,9 @@ stream_{new test::FakeStream{task_runner_}}, fake_stream_{stream_.get()} {} - void Connect(const base::Callback<void(std::unique_ptr<Stream>)>& callback) { - callback.Run(std::move(stream_)); + void Connect(const base::Callback<void(std::unique_ptr<Stream>, + ErrorPtr error)>& callback) { + callback.Run(std::move(stream_), nullptr); } XmppState state() const { return state_; } @@ -121,7 +122,7 @@ class XmppChannelTest : public ::testing::Test { protected: XmppChannelTest() { - EXPECT_CALL(network_, OpenSslSocket("talk.google.com", 5223, _, _)) + EXPECT_CALL(network_, OpenSslSocket("talk.google.com", 5223, _)) .WillOnce( WithArgs<2>(Invoke(&xmpp_client_, &FakeXmppChannel::Connect))); }
diff --git a/libweave/src/privet/cloud_delegate.cc b/libweave/src/privet/cloud_delegate.cc index 3df5b86..81b7e33 100644 --- a/libweave/src/privet/cloud_delegate.cc +++ b/libweave/src/privet/cloud_delegate.cc
@@ -154,8 +154,7 @@ void AddCommand(const base::DictionaryValue& command, const UserInfo& user_info, - const CommandSuccessCallback& success_callback, - const ErrorCallback& error_callback) override { + const CommandDoneCallback& callback) override { CHECK(user_info.scope() != AuthScope::kNone); CHECK_NE(user_info.user_id(), 0u); @@ -166,44 +165,41 @@ Error::AddToPrintf(&error, FROM_HERE, errors::kDomain, errors::kInvalidParams, "Invalid role: '%s'", str_scope.c_str()); - return error_callback.Run(std::move(error)); + return callback.Run({}, std::move(error)); } std::string id; if (!command_manager_->AddCommand(command, role, &id, &error)) - return error_callback.Run(std::move(error)); + return callback.Run({}, std::move(error)); command_owners_[id] = user_info.user_id(); - success_callback.Run(*command_manager_->FindCommand(id)->ToJson()); + callback.Run(*command_manager_->FindCommand(id)->ToJson(), nullptr); } void GetCommand(const std::string& id, const UserInfo& user_info, - const CommandSuccessCallback& success_callback, - const ErrorCallback& error_callback) override { + const CommandDoneCallback& callback) override { CHECK(user_info.scope() != AuthScope::kNone); ErrorPtr error; auto command = GetCommandInternal(id, user_info, &error); if (!command) - return error_callback.Run(std::move(error)); - success_callback.Run(*command->ToJson()); + return callback.Run({}, std::move(error)); + callback.Run(*command->ToJson(), nullptr); } void CancelCommand(const std::string& id, const UserInfo& user_info, - const CommandSuccessCallback& success_callback, - const ErrorCallback& error_callback) override { + const CommandDoneCallback& callback) override { CHECK(user_info.scope() != AuthScope::kNone); ErrorPtr error; auto command = GetCommandInternal(id, user_info, &error); if (!command || !command->Cancel(&error)) - return error_callback.Run(std::move(error)); - success_callback.Run(*command->ToJson()); + return callback.Run({}, std::move(error)); + callback.Run(*command->ToJson(), nullptr); } void ListCommands(const UserInfo& user_info, - const CommandSuccessCallback& success_callback, - const ErrorCallback& error_callback) override { + const CommandDoneCallback& callback) override { CHECK(user_info.scope() != AuthScope::kNone); base::ListValue list_value; @@ -218,7 +214,7 @@ base::DictionaryValue commands_json; commands_json.Set("commands", list_value.DeepCopy()); - success_callback.Run(commands_json); + callback.Run(commands_json, nullptr); } private: @@ -285,29 +281,27 @@ } device_->RegisterDevice( - ticket_id, base::Bind(&CloudDelegateImpl::RegisterDeviceSuccess, - setup_weak_factory_.GetWeakPtr()), - base::Bind(&CloudDelegateImpl::RegisterDeviceError, + ticket_id, + base::Bind(&CloudDelegateImpl::RegisterDeviceDone, setup_weak_factory_.GetWeakPtr(), ticket_id, deadline)); } - void RegisterDeviceSuccess() { + void RegisterDeviceDone(const std::string& ticket_id, + const base::Time& deadline, + ErrorPtr error) { + if (error) { + // Registration failed. Retry with backoff. + backoff_entry_.InformOfRequest(false); + return task_runner_->PostDelayedTask( + FROM_HERE, + base::Bind(&CloudDelegateImpl::CallManagerRegisterDevice, + setup_weak_factory_.GetWeakPtr(), ticket_id, deadline), + backoff_entry_.GetTimeUntilRelease()); + } backoff_entry_.InformOfRequest(true); setup_state_ = SetupState(SetupState::kSuccess); } - void RegisterDeviceError(const std::string& ticket_id, - const base::Time& deadline, - ErrorPtr error) { - // Registration failed. Retry with backoff. - backoff_entry_.InformOfRequest(false); - task_runner_->PostDelayedTask( - FROM_HERE, - base::Bind(&CloudDelegateImpl::CallManagerRegisterDevice, - setup_weak_factory_.GetWeakPtr(), ticket_id, deadline), - backoff_entry_.GetTimeUntilRelease()); - } - CommandInstance* GetCommandInternal(const std::string& command_id, const UserInfo& user_info, ErrorPtr* error) const {
diff --git a/libweave/src/privet/cloud_delegate.h b/libweave/src/privet/cloud_delegate.h index 74456d3..05ba8a6 100644 --- a/libweave/src/privet/cloud_delegate.h +++ b/libweave/src/privet/cloud_delegate.h
@@ -39,8 +39,9 @@ CloudDelegate(); virtual ~CloudDelegate(); - using CommandSuccessCallback = - base::Callback<void(const base::DictionaryValue& commands)>; + using CommandDoneCallback = + base::Callback<void(const base::DictionaryValue& commands, + ErrorPtr error)>; class Observer { public: @@ -107,25 +108,21 @@ // Adds command created from the given JSON representation. virtual void AddCommand(const base::DictionaryValue& command, const UserInfo& user_info, - const CommandSuccessCallback& success_callback, - const ErrorCallback& error_callback) = 0; + const CommandDoneCallback& callback) = 0; // Returns command with the given ID. virtual void GetCommand(const std::string& id, const UserInfo& user_info, - const CommandSuccessCallback& success_callback, - const ErrorCallback& error_callback) = 0; + const CommandDoneCallback& callback) = 0; // Cancels command with the given ID. virtual void CancelCommand(const std::string& id, const UserInfo& user_info, - const CommandSuccessCallback& success_callback, - const ErrorCallback& error_callback) = 0; + const CommandDoneCallback& callback) = 0; // Lists commands. virtual void ListCommands(const UserInfo& user_info, - const CommandSuccessCallback& success_callback, - const ErrorCallback& error_callback) = 0; + const CommandDoneCallback& callback) = 0; void AddObserver(Observer* observer) { observer_list_.AddObserver(observer); } void RemoveObserver(Observer* observer) {
diff --git a/libweave/src/privet/mock_delegates.h b/libweave/src/privet/mock_delegates.h index 48227ae..f16b526 100644 --- a/libweave/src/privet/mock_delegates.h +++ b/libweave/src/privet/mock_delegates.h
@@ -154,25 +154,19 @@ MOCK_CONST_METHOD0(GetCloudId, std::string()); MOCK_CONST_METHOD0(GetState, const base::DictionaryValue&()); MOCK_CONST_METHOD0(GetCommandDef, const base::DictionaryValue&()); - MOCK_METHOD4(AddCommand, + MOCK_METHOD3(AddCommand, void(const base::DictionaryValue&, const UserInfo&, - const CommandSuccessCallback&, - const ErrorCallback&)); - MOCK_METHOD4(GetCommand, + const CommandDoneCallback&)); + MOCK_METHOD3(GetCommand, void(const std::string&, const UserInfo&, - const CommandSuccessCallback&, - const ErrorCallback&)); - MOCK_METHOD4(CancelCommand, + const CommandDoneCallback&)); + MOCK_METHOD3(CancelCommand, void(const std::string&, const UserInfo&, - const CommandSuccessCallback&, - const ErrorCallback&)); - MOCK_METHOD3(ListCommands, - void(const UserInfo&, - const CommandSuccessCallback&, - const ErrorCallback&)); + const CommandDoneCallback&)); + MOCK_METHOD2(ListCommands, void(const UserInfo&, const CommandDoneCallback&)); MockCloudDelegate() { EXPECT_CALL(*this, GetDeviceId()).WillRepeatedly(Return("TestId"));
diff --git a/libweave/src/privet/privet_handler.cc b/libweave/src/privet/privet_handler.cc index 702104d..ad1d3d0 100644 --- a/libweave/src/privet/privet_handler.cc +++ b/libweave/src/privet/privet_handler.cc
@@ -198,23 +198,20 @@ } void OnCommandRequestSucceeded(const PrivetHandler::RequestCallback& callback, - const base::DictionaryValue& output) { - callback.Run(http::kOk, output); -} + const base::DictionaryValue& output, + ErrorPtr error) { + if (!error) + return callback.Run(http::kOk, output); -void OnCommandRequestFailed(const PrivetHandler::RequestCallback& callback, - ErrorPtr error) { if (error->HasError("gcd", "unknown_command")) { - ErrorPtr new_error = error->Clone(); - Error::AddTo(&new_error, FROM_HERE, errors::kDomain, errors::kNotFound, + Error::AddTo(&error, FROM_HERE, errors::kDomain, errors::kNotFound, "Unknown command ID"); - return ReturnError(*new_error, callback); + return ReturnError(*error, callback); } if (error->HasError("gcd", "access_denied")) { - ErrorPtr new_error = error->Clone(); - Error::AddTo(&new_error, FROM_HERE, errors::kDomain, errors::kAccessDenied, + Error::AddTo(&error, FROM_HERE, errors::kDomain, errors::kAccessDenied, error->GetMessage()); - return ReturnError(*new_error, callback); + return ReturnError(*error, callback); } return ReturnError(*error, callback); } @@ -768,8 +765,7 @@ const UserInfo& user_info, const RequestCallback& callback) { cloud_->AddCommand(input, user_info, - base::Bind(&OnCommandRequestSucceeded, callback), - base::Bind(&OnCommandRequestFailed, callback)); + base::Bind(&OnCommandRequestSucceeded, callback)); } void PrivetHandler::HandleCommandsStatus(const base::DictionaryValue& input, @@ -784,16 +780,14 @@ return ReturnError(*error, callback); } cloud_->GetCommand(id, user_info, - base::Bind(&OnCommandRequestSucceeded, callback), - base::Bind(&OnCommandRequestFailed, callback)); + base::Bind(&OnCommandRequestSucceeded, callback)); } void PrivetHandler::HandleCommandsList(const base::DictionaryValue& input, const UserInfo& user_info, const RequestCallback& callback) { cloud_->ListCommands(user_info, - base::Bind(&OnCommandRequestSucceeded, callback), - base::Bind(&OnCommandRequestFailed, callback)); + base::Bind(&OnCommandRequestSucceeded, callback)); } void PrivetHandler::HandleCommandsCancel(const base::DictionaryValue& input, @@ -808,8 +802,7 @@ return ReturnError(*error, callback); } cloud_->CancelCommand(id, user_info, - base::Bind(&OnCommandRequestSucceeded, callback), - base::Bind(&OnCommandRequestFailed, callback)); + base::Bind(&OnCommandRequestSucceeded, callback)); } } // namespace privet
diff --git a/libweave/src/privet/privet_handler_unittest.cc b/libweave/src/privet/privet_handler_unittest.cc index 42dd956..dba394a 100644 --- a/libweave/src/privet/privet_handler_unittest.cc +++ b/libweave/src/privet/privet_handler_unittest.cc
@@ -649,8 +649,11 @@ base::DictionaryValue command; LoadTestJson(kInput, &command); LoadTestJson("{'id':'5'}", &command); - EXPECT_CALL(cloud_, AddCommand(_, _, _, _)) - .WillOnce(RunCallback<2, const base::DictionaryValue&>(command)); + EXPECT_CALL(cloud_, AddCommand(_, _, _)) + .WillOnce(WithArgs<2>(Invoke( + [&command](const CloudDelegate::CommandDoneCallback& callback) { + callback.Run(command, nullptr); + }))); EXPECT_PRED2(IsEqualJson, "{'name':'test', 'id':'5'}", HandleRequest("/privet/v3/commands/execute", kInput)); @@ -661,18 +664,22 @@ base::DictionaryValue command; LoadTestJson(kInput, &command); LoadTestJson("{'name':'test'}", &command); - EXPECT_CALL(cloud_, GetCommand(_, _, _, _)) - .WillOnce(RunCallback<2, const base::DictionaryValue&>(command)); + EXPECT_CALL(cloud_, GetCommand(_, _, _)) + .WillOnce(WithArgs<2>(Invoke( + [&command](const CloudDelegate::CommandDoneCallback& callback) { + callback.Run(command, nullptr); + }))); EXPECT_PRED2(IsEqualJson, "{'name':'test', 'id':'5'}", HandleRequest("/privet/v3/commands/status", kInput)); ErrorPtr error; Error::AddTo(&error, FROM_HERE, errors::kDomain, "notFound", ""); - EXPECT_CALL(cloud_, GetCommand(_, _, _, _)) - .WillOnce(WithArgs<3>(Invoke([&error](const ErrorCallback& callback) { - callback.Run(std::move(error)); - }))); + EXPECT_CALL(cloud_, GetCommand(_, _, _)) + .WillOnce(WithArgs<2>( + Invoke([&error](const CloudDelegate::CommandDoneCallback& callback) { + callback.Run({}, std::move(error)); + }))); EXPECT_PRED2(IsEqualError, CodeWithReason(404, "notFound"), HandleRequest("/privet/v3/commands/status", "{'id': '15'}")); @@ -682,18 +689,22 @@ const char kExpected[] = "{'id': '5', 'name':'test', 'state':'cancelled'}"; base::DictionaryValue command; LoadTestJson(kExpected, &command); - EXPECT_CALL(cloud_, CancelCommand(_, _, _, _)) - .WillOnce(RunCallback<2, const base::DictionaryValue&>(command)); + EXPECT_CALL(cloud_, CancelCommand(_, _, _)) + .WillOnce(WithArgs<2>(Invoke( + [&command](const CloudDelegate::CommandDoneCallback& callback) { + callback.Run(command, nullptr); + }))); EXPECT_PRED2(IsEqualJson, kExpected, HandleRequest("/privet/v3/commands/cancel", "{'id': '8'}")); ErrorPtr error; Error::AddTo(&error, FROM_HERE, errors::kDomain, "notFound", ""); - EXPECT_CALL(cloud_, CancelCommand(_, _, _, _)) - .WillOnce(WithArgs<3>(Invoke([&error](const ErrorCallback& callback) { - callback.Run(std::move(error)); - }))); + EXPECT_CALL(cloud_, CancelCommand(_, _, _)) + .WillOnce(WithArgs<2>( + Invoke([&error](const CloudDelegate::CommandDoneCallback& callback) { + callback.Run({}, std::move(error)); + }))); EXPECT_PRED2(IsEqualError, CodeWithReason(404, "notFound"), HandleRequest("/privet/v3/commands/cancel", "{'id': '11'}")); @@ -709,8 +720,11 @@ base::DictionaryValue commands; LoadTestJson(kExpected, &commands); - EXPECT_CALL(cloud_, ListCommands(_, _, _)) - .WillOnce(RunCallback<1, const base::DictionaryValue&>(commands)); + EXPECT_CALL(cloud_, ListCommands(_, _)) + .WillOnce(WithArgs<1>(Invoke( + [&commands](const CloudDelegate::CommandDoneCallback& callback) { + callback.Run(commands, nullptr); + }))); EXPECT_PRED2(IsEqualJson, kExpected, HandleRequest("/privet/v3/commands/list", "{}"));
diff --git a/libweave/src/privet/wifi_bootstrap_manager.cc b/libweave/src/privet/wifi_bootstrap_manager.cc index b4b0153..1d2d813 100644 --- a/libweave/src/privet/wifi_bootstrap_manager.cc +++ b/libweave/src/privet/wifi_bootstrap_manager.cc
@@ -96,21 +96,12 @@ << ", pass=" << passphrase << ")."; UpdateState(State::kConnecting); task_runner_->PostDelayedTask( - FROM_HERE, base::Bind(&WifiBootstrapManager::OnConnectError, - tasks_weak_factory_.GetWeakPtr(), nullptr), + FROM_HERE, base::Bind(&WifiBootstrapManager::OnConnectTimeout, + tasks_weak_factory_.GetWeakPtr()), base::TimeDelta::FromMinutes(3)); wifi_->Connect(ssid, passphrase, - base::Bind(&WifiBootstrapManager::OnConnectSuccess, - tasks_weak_factory_.GetWeakPtr(), ssid), - base::Bind(&WifiBootstrapManager::OnConnectError, - tasks_weak_factory_.GetWeakPtr())); -} - -void WifiBootstrapManager::OnConnectError(ErrorPtr error) { - Error::AddTo(&error, FROM_HERE, errors::kDomain, errors::kInvalidState, - "Failed to connect to provided network"); - setup_state_ = SetupState{std::move(error)}; - StartBootstrapping(); + base::Bind(&WifiBootstrapManager::OnConnectDone, + tasks_weak_factory_.GetWeakPtr(), ssid)); } void WifiBootstrapManager::EndConnecting() {} @@ -203,7 +194,14 @@ return {WifiType::kWifi24}; } -void WifiBootstrapManager::OnConnectSuccess(const std::string& ssid) { +void WifiBootstrapManager::OnConnectDone(const std::string& ssid, + ErrorPtr error) { + if (error) { + Error::AddTo(&error, FROM_HERE, errors::kDomain, errors::kInvalidState, + "Failed to connect to provided network"); + setup_state_ = SetupState{std::move(error)}; + return StartBootstrapping(); + } VLOG(1) << "Wifi was connected successfully"; Config::Transaction change{config_}; change.set_last_configured_ssid(ssid); @@ -212,6 +210,14 @@ StartMonitoring(); } +void WifiBootstrapManager::OnConnectTimeout() { + ErrorPtr error; + Error::AddTo(&error, FROM_HERE, errors::kDomain, errors::kInvalidState, + "Timeout connecting to provided network"); + setup_state_ = SetupState{std::move(error)}; + return StartBootstrapping(); +} + void WifiBootstrapManager::OnBootstrapTimeout() { VLOG(1) << "Bootstrapping has timed out."; StartMonitoring();
diff --git a/libweave/src/privet/wifi_bootstrap_manager.h b/libweave/src/privet/wifi_bootstrap_manager.h index 390af31..c0a1c24 100644 --- a/libweave/src/privet/wifi_bootstrap_manager.h +++ b/libweave/src/privet/wifi_bootstrap_manager.h
@@ -87,8 +87,8 @@ // to return to monitoring mode periodically in case our connectivity issues // were temporary. void OnBootstrapTimeout(); - void OnConnectSuccess(const std::string& ssid); - void OnConnectError(ErrorPtr error); + void OnConnectDone(const std::string& ssid, ErrorPtr error); + void OnConnectTimeout(); void OnConnectivityChange(); void OnMonitorTimeout(); void UpdateConnectionState();
diff --git a/libweave/src/streams.cc b/libweave/src/streams.cc index 1dc0355..0516f66 100644 --- a/libweave/src/streams.cc +++ b/libweave/src/streams.cc
@@ -19,52 +19,54 @@ void MemoryStream::Read(void* buffer, size_t size_to_read, - const ReadSuccessCallback& success_callback, - const ErrorCallback& error_callback) { + const ReadCallback& callback) { CHECK_LE(read_position_, data_.size()); size_t size_read = std::min(size_to_read, data_.size() - read_position_); if (size_read > 0) memcpy(buffer, data_.data() + read_position_, size_read); read_position_ += size_read; task_runner_->PostDelayedTask(FROM_HERE, - base::Bind(success_callback, size_read), {}); + base::Bind(callback, size_read, nullptr), {}); } void MemoryStream::Write(const void* buffer, size_t size_to_write, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) { + const WriteCallback& callback) { data_.insert(data_.end(), static_cast<const char*>(buffer), static_cast<const char*>(buffer) + size_to_write); - task_runner_->PostDelayedTask(FROM_HERE, success_callback, {}); + task_runner_->PostDelayedTask(FROM_HERE, base::Bind(callback, nullptr), {}); } StreamCopier::StreamCopier(InputStream* source, OutputStream* destination) : source_{source}, destination_{destination}, buffer_(4096) {} -void StreamCopier::Copy( - const InputStream::ReadSuccessCallback& success_callback, - const ErrorCallback& error_callback) { - source_->Read( - buffer_.data(), buffer_.size(), - base::Bind(&StreamCopier::OnSuccessRead, weak_ptr_factory_.GetWeakPtr(), - success_callback, error_callback), - error_callback); +void StreamCopier::Copy(const InputStream::ReadCallback& callback) { + source_->Read(buffer_.data(), buffer_.size(), + base::Bind(&StreamCopier::OnReadDone, + weak_ptr_factory_.GetWeakPtr(), callback)); } -void StreamCopier::OnSuccessRead( - const InputStream::ReadSuccessCallback& success_callback, - const ErrorCallback& error_callback, - size_t size) { +void StreamCopier::OnReadDone(const InputStream::ReadCallback& callback, + size_t size, + ErrorPtr error) { + if (error) + return callback.Run(0, std::move(error)); + size_done_ += size; if (size) { return destination_->Write( buffer_.data(), size, - base::Bind(&StreamCopier::Copy, weak_ptr_factory_.GetWeakPtr(), - success_callback, error_callback), - error_callback); + base::Bind(&StreamCopier::OnWriteDone, weak_ptr_factory_.GetWeakPtr(), + callback)); } - success_callback.Run(size_done_); + callback.Run(size_done_, nullptr); +} + +void StreamCopier::OnWriteDone(const InputStream::ReadCallback& callback, + ErrorPtr error) { + if (error) + return callback.Run(size_done_, std::move(error)); + Copy(callback); } } // namespace weave
diff --git a/libweave/src/streams.h b/libweave/src/streams.h index 0a21737..6d2a1d0 100644 --- a/libweave/src/streams.h +++ b/libweave/src/streams.h
@@ -21,13 +21,11 @@ void Read(void* buffer, size_t size_to_read, - const ReadSuccessCallback& success_callback, - const ErrorCallback& error_callback) override; + const ReadCallback& callback) override; void Write(const void* buffer, size_t size_to_write, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) override; + const WriteCallback& callback) override; const std::vector<uint8_t>& GetData() const { return data_; } @@ -41,13 +39,13 @@ public: StreamCopier(InputStream* source, OutputStream* destination); - void Copy(const InputStream::ReadSuccessCallback& success_callback, - const ErrorCallback& error_callback); + void Copy(const InputStream::ReadCallback& callback); private: - void OnSuccessRead(const InputStream::ReadSuccessCallback& success_callback, - const ErrorCallback& error_callback, - size_t size); + void OnWriteDone(const InputStream::ReadCallback& callback, ErrorPtr error); + void OnReadDone(const InputStream::ReadCallback& callback, + size_t size, + ErrorPtr error); InputStream* source_{nullptr}; OutputStream* destination_{nullptr};
diff --git a/libweave/src/streams_unittest.cc b/libweave/src/streams_unittest.cc index 3cef6f0..b669328 100644 --- a/libweave/src/streams_unittest.cc +++ b/libweave/src/streams_unittest.cc
@@ -23,13 +23,14 @@ bool done = false; - auto on_success = base::Bind([&test_data, &done, &destination](size_t size) { - done = true; - EXPECT_EQ(test_data, destination.GetData()); - }); - auto on_error = base::Bind([](ErrorPtr error) { ADD_FAILURE(); }); + auto callback = base::Bind( + [&test_data, &done, &destination](size_t size, ErrorPtr error) { + EXPECT_FALSE(error); + done = true; + EXPECT_EQ(test_data, destination.GetData()); + }); StreamCopier copier{&source, &destination}; - copier.Copy(on_success, on_error); + copier.Copy(callback); task_runner.Run(test_data.size()); EXPECT_TRUE(done);
diff --git a/libweave/src/test/fake_stream.cc b/libweave/src/test/fake_stream.cc index 786aa01..7c86c4a 100644 --- a/libweave/src/test/fake_stream.cc +++ b/libweave/src/test/fake_stream.cc
@@ -30,33 +30,30 @@ void FakeStream::Read(void* buffer, size_t size_to_read, - const ReadSuccessCallback& success_callback, - const ErrorCallback& error_callback) { + const ReadCallback& callback) { if (read_data_.empty()) { task_runner_->PostDelayedTask( - FROM_HERE, - base::Bind(&FakeStream::Read, base::Unretained(this), buffer, - size_to_read, success_callback, error_callback), + FROM_HERE, base::Bind(&FakeStream::Read, base::Unretained(this), buffer, + size_to_read, callback), base::TimeDelta::FromSeconds(0)); return; } size_t size = std::min(size_to_read, read_data_.size()); memcpy(buffer, read_data_.data(), size); read_data_ = read_data_.substr(size); - task_runner_->PostDelayedTask(FROM_HERE, base::Bind(success_callback, size), + task_runner_->PostDelayedTask(FROM_HERE, base::Bind(callback, size, nullptr), base::TimeDelta::FromSeconds(0)); } void FakeStream::Write(const void* buffer, size_t size_to_write, - const SuccessCallback& success_callback, - const ErrorCallback& error_callback) { + const WriteCallback& callback) { size_t size = std::min(size_to_write, write_data_.size()); EXPECT_EQ( write_data_.substr(0, size), std::string(reinterpret_cast<const char*>(buffer), size_to_write)); write_data_ = write_data_.substr(size); - task_runner_->PostDelayedTask(FROM_HERE, success_callback, + task_runner_->PostDelayedTask(FROM_HERE, base::Bind(callback, nullptr), base::TimeDelta::FromSeconds(0)); }
diff --git a/libweave/src/weave_unittest.cc b/libweave/src/weave_unittest.cc index 8d87229..74e8c41 100644 --- a/libweave/src/weave_unittest.cc +++ b/libweave/src/weave_unittest.cc
@@ -140,9 +140,9 @@ void ExpectRequest(HttpClient::Method method, const std::string& url, const std::string& json_response) { - EXPECT_CALL(http_client_, SendRequest(method, url, _, _, _, _)) + EXPECT_CALL(http_client_, SendRequest(method, url, _, _, _)) .WillOnce(WithArgs<4>(Invoke([json_response]( - const HttpClient::SuccessCallback& callback) { + const HttpClient::SendRequestCallback& callback) { std::unique_ptr<provider::test::MockHttpClientResponse> response{ new StrictMock<provider::test::MockHttpClientResponse>}; EXPECT_CALL(*response, GetStatusCode()) @@ -154,7 +154,7 @@ EXPECT_CALL(*response, GetData()) .Times(AtLeast(1)) .WillRepeatedly(Return(json_response)); - callback.Run(*response); + callback.Run(std::move(response), nullptr); }))); } @@ -306,7 +306,7 @@ } TEST_F(WeaveBasicTest, Register) { - EXPECT_CALL(network_, OpenSslSocket(_, _, _, _)).WillRepeatedly(Return()); + EXPECT_CALL(network_, OpenSslSocket(_, _, _)).WillRepeatedly(Return()); StartDevice(); auto draft = CreateDictionaryValue(kDeviceResource); @@ -333,12 +333,12 @@ InitDnsSdPublishing(true, "DB"); bool done = false; - device_->Register("TICKET_ID", base::Bind([this, &done]() { + device_->Register("TICKET_ID", base::Bind([this, &done](ErrorPtr error) { + EXPECT_FALSE(error); done = true; task_runner_.Break(); EXPECT_EQ("CLOUD_ID", device_->GetSettings().cloud_id); - }), - base::Bind([](ErrorPtr error) { ADD_FAILURE(); })); + })); task_runner_.Run(); EXPECT_TRUE(done); }