Use single callback for replies to async operations
Single callback simplifies copying of callbacks and makes
control flow more obvious.
BUG:24267885
Change-Id: I489e7158e2bb1adf8c9c3966a0859fa024a57db2
Reviewed-on: https://weave-review.googlesource.com/1302
Reviewed-by: Vitaly Buka <vitalybuka@google.com>
diff --git a/libweave/examples/ubuntu/curl_http_client.cc b/libweave/examples/ubuntu/curl_http_client.cc
index c058a9b..653cbb3 100644
--- a/libweave/examples/ubuntu/curl_http_client.cc
+++ b/libweave/examples/ubuntu/curl_http_client.cc
@@ -35,18 +35,11 @@
CurlHttpClient::CurlHttpClient(provider::TaskRunner* task_runner)
: task_runner_{task_runner} {}
-void CurlHttpClient::PostError(const ErrorCallback& error_callback,
- ErrorPtr error) {
- task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(error_callback, base::Passed(&error)), {});
-}
-
void CurlHttpClient::SendRequest(Method method,
const std::string& url,
const Headers& headers,
const std::string& data,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) {
+ const SendRequestCallback& callback) {
std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl{curl_easy_init(),
&curl_easy_cleanup};
CHECK(curl);
@@ -96,7 +89,8 @@
if (res != CURLE_OK) {
Error::AddTo(&error, FROM_HERE, "curl", "curl_easy_perform_error",
curl_easy_strerror(res));
- return PostError(error_callback, std::move(error));
+ return task_runner_->PostDelayedTask(
+ FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
}
const std::string kContentType = "\r\nContent-Type:";
@@ -104,7 +98,8 @@
if (pos == std::string::npos) {
Error::AddTo(&error, FROM_HERE, "curl", "no_content_header",
"Content-Type header is missing");
- return PostError(error_callback, std::move(error));
+ return task_runner_->PostDelayedTask(
+ FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
}
pos += kContentType.size();
auto pos_end = response->content_type.find("\r\n", pos);
@@ -118,15 +113,7 @@
&response->status));
task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(&CurlHttpClient::RunSuccessCallback,
- weak_ptr_factory_.GetWeakPtr(), success_callback,
- base::Passed(&response)),
- {});
-}
-
-void CurlHttpClient::RunSuccessCallback(const SuccessCallback& success_callback,
- std::unique_ptr<Response> response) {
- success_callback.Run(*response);
+ FROM_HERE, base::Bind(callback, base::Passed(&response), nullptr), {});
}
} // namespace examples
diff --git a/libweave/examples/ubuntu/curl_http_client.h b/libweave/examples/ubuntu/curl_http_client.h
index a5090ff..6bbb3f4 100644
--- a/libweave/examples/ubuntu/curl_http_client.h
+++ b/libweave/examples/ubuntu/curl_http_client.h
@@ -28,14 +28,9 @@
const std::string& url,
const Headers& headers,
const std::string& data,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) override;
+ const SendRequestCallback& callback) override;
private:
- void PostError(const ErrorCallback& error_callback, ErrorPtr error);
-
- void RunSuccessCallback(const SuccessCallback& success_callback,
- std::unique_ptr<Response> response);
provider::TaskRunner* task_runner_{nullptr};
base::WeakPtrFactory<CurlHttpClient> weak_ptr_factory_{this};
diff --git a/libweave/examples/ubuntu/event_http_client.cc b/libweave/examples/ubuntu/event_http_client.cc
index 8341a07..637ae2f 100644
--- a/libweave/examples/ubuntu/event_http_client.cc
+++ b/libweave/examples/ubuntu/event_http_client.cc
@@ -60,8 +60,7 @@
TaskRunner* task_runner_;
std::unique_ptr<evhttp_uri, EventDeleter> http_uri_;
std::unique_ptr<evhttp_connection, EventDeleter> evcon_;
- HttpClient::SuccessCallback success_callback_;
- ErrorCallback error_callback_;
+ HttpClient::SendRequestCallback callback_;
};
void RequestDoneCallback(evhttp_request* req, void* ctx) {
@@ -74,7 +73,7 @@
"request failed: %s",
evutil_socket_error_to_string(err));
state->task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(state->error_callback_, base::Passed(&error)),
+ FROM_HERE, base::Bind(state->callback_, nullptr, base::Passed(&error)),
{});
return;
}
@@ -86,7 +85,8 @@
auto n = evbuffer_remove(buffer, &response->data[0], length);
CHECK_EQ(n, int(length));
state->task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(state->success_callback_, *response), {});
+ FROM_HERE, base::Bind(state->callback_, base::Passed(&response), nullptr),
+ {});
}
} // namespace
@@ -98,8 +98,7 @@
const std::string& url,
const Headers& headers,
const std::string& data,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) {
+ const SendRequestCallback& callback) {
evhttp_cmd_type method_id;
CHECK(weave::StringToEnum(weave::EnumToString(method), &method_id));
std::unique_ptr<evhttp_uri, EventDeleter> http_uri{
@@ -129,7 +128,7 @@
std::unique_ptr<evhttp_request, EventDeleter> req{evhttp_request_new(
&RequestDoneCallback,
new EventRequestState{task_runner_, std::move(http_uri), std::move(conn),
- success_callback, error_callback})};
+ callback})};
CHECK(req);
auto output_headers = evhttp_request_get_output_headers(req.get());
evhttp_add_header(output_headers, "Host", host);
@@ -150,7 +149,7 @@
"request failed: %s %s", EnumToString(method).c_str(),
url.c_str());
task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(error_callback, base::Passed(&error)), {});
+ FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
}
} // namespace examples
diff --git a/libweave/examples/ubuntu/event_http_client.h b/libweave/examples/ubuntu/event_http_client.h
index 457e550..e564e8a 100644
--- a/libweave/examples/ubuntu/event_http_client.h
+++ b/libweave/examples/ubuntu/event_http_client.h
@@ -24,8 +24,7 @@
const std::string& url,
const Headers& headers,
const std::string& data,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) override;
+ const SendRequestCallback& callback) override;
private:
EventTaskRunner* task_runner_{nullptr};
diff --git a/libweave/examples/ubuntu/main.cc b/libweave/examples/ubuntu/main.cc
index e762791..6324411 100644
--- a/libweave/examples/ubuntu/main.cc
+++ b/libweave/examples/ubuntu/main.cc
@@ -221,12 +221,11 @@
base::WeakPtrFactory<CommandHandler> weak_ptr_factory_{this};
};
-void RegisterDeviceSuccess(weave::Device* device) {
- LOG(INFO) << "Device registered: " << device->GetSettings().cloud_id;
-}
-
-void RegisterDeviceError(weave::ErrorPtr error) {
- LOG(ERROR) << "Fail to register device: " << error->GetMessage();
+void OnRegisterDeviceDone(weave::Device* device, weave::ErrorPtr error) {
+ if (error)
+ LOG(ERROR) << "Fail to register device: " << error->GetMessage();
+ else
+ LOG(INFO) << "Device registered: " << device->GetSettings().cloud_id;
}
} // namespace
@@ -280,8 +279,7 @@
if (!registration_ticket.empty()) {
device->Register(registration_ticket,
- base::Bind(&RegisterDeviceSuccess, device.get()),
- base::Bind(&RegisterDeviceError));
+ base::Bind(&OnRegisterDeviceDone, device.get()));
}
CommandHandler handler(device.get(), &task_runner);
diff --git a/libweave/examples/ubuntu/netlink_network.cc b/libweave/examples/ubuntu/netlink_network.cc
index 24d0996..60807a4 100644
--- a/libweave/examples/ubuntu/netlink_network.cc
+++ b/libweave/examples/ubuntu/netlink_network.cc
@@ -91,23 +91,22 @@
return network_state_;
}
-void NetlinkNetworkImpl::OpenSslSocket(
- const std::string& host,
- uint16_t port,
- const OpenSslSocketSuccessCallback& success_callback,
- const ErrorCallback& error_callback) {
+void NetlinkNetworkImpl::OpenSslSocket(const std::string& host,
+ uint16_t port,
+ const OpenSslSocketCallback& callback) {
// Connect to SSL port instead of upgrading to TLS.
std::unique_ptr<SSLStream> tls_stream{new SSLStream{task_runner_}};
if (tls_stream->Init(host, port)) {
task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(success_callback, base::Passed(&tls_stream)), {});
+ FROM_HERE, base::Bind(callback, base::Passed(&tls_stream), nullptr),
+ {});
} else {
ErrorPtr error;
Error::AddTo(&error, FROM_HERE, "tls", "tls_init_failed",
"Failed to initialize TLS stream.");
task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(error_callback, base::Passed(&error)), {});
+ FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
}
}
diff --git a/libweave/examples/ubuntu/netlink_network.h b/libweave/examples/ubuntu/netlink_network.h
index 3f5f51a..af8bf13 100644
--- a/libweave/examples/ubuntu/netlink_network.h
+++ b/libweave/examples/ubuntu/netlink_network.h
@@ -29,8 +29,7 @@
State GetConnectionState() const override;
void OpenSslSocket(const std::string& host,
uint16_t port,
- const OpenSslSocketSuccessCallback& success_callback,
- const ErrorCallback& error_callback) override;
+ const OpenSslSocketCallback& callback) override;
private:
class Deleter {
diff --git a/libweave/examples/ubuntu/network_manager.cc b/libweave/examples/ubuntu/network_manager.cc
index ff54f0b..e6dd9d8 100644
--- a/libweave/examples/ubuntu/network_manager.cc
+++ b/libweave/examples/ubuntu/network_manager.cc
@@ -61,8 +61,7 @@
const std::string& passphrase,
int pid,
base::Time until,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) {
+ const DoneCallback& callback) {
if (pid) {
int status = 0;
if (pid == waitpid(pid, &status, WNOWAIT)) {
@@ -79,7 +78,8 @@
close(sockf_d);
if (ssid == essid)
- return task_runner_->PostDelayedTask(FROM_HERE, success_callback, {});
+ return task_runner_->PostDelayedTask(FROM_HERE,
+ base::Bind(callback, nullptr), {});
pid = 0; // Try again.
}
}
@@ -94,34 +94,32 @@
Error::AddTo(&error, FROM_HERE, "wifi", "timeout",
"Timeout connecting to WiFI network.");
task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(error_callback, base::Passed(&error)), {});
+ FROM_HERE, base::Bind(callback, base::Passed(&error)), {});
return;
}
task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(&NetworkImpl::TryToConnect,
- weak_ptr_factory_.GetWeakPtr(), ssid, passphrase,
- pid, until, success_callback, error_callback),
+ FROM_HERE,
+ base::Bind(&NetworkImpl::TryToConnect, weak_ptr_factory_.GetWeakPtr(),
+ ssid, passphrase, pid, until, callback),
base::TimeDelta::FromSeconds(1));
}
void NetworkImpl::Connect(const std::string& ssid,
const std::string& passphrase,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) {
+ const DoneCallback& callback) {
force_bootstrapping_ = false;
CHECK(!hostapd_started_);
if (hostapd_started_) {
ErrorPtr error;
Error::AddTo(&error, FROM_HERE, "wifi", "busy", "Running Access Point.");
task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(error_callback, base::Passed(&error)), {});
+ FROM_HERE, base::Bind(callback, base::Passed(&error)), {});
return;
}
TryToConnect(ssid, passphrase, 0,
- base::Time::Now() + base::TimeDelta::FromMinutes(1),
- success_callback, error_callback);
+ base::Time::Now() + base::TimeDelta::FromMinutes(1), callback);
}
void NetworkImpl::UpdateNetworkState() {
@@ -200,23 +198,22 @@
return std::system("nmcli dev | grep ^wlan0") == 0;
}
-void NetworkImpl::OpenSslSocket(
- const std::string& host,
- uint16_t port,
- const OpenSslSocketSuccessCallback& success_callback,
- const ErrorCallback& error_callback) {
+void NetworkImpl::OpenSslSocket(const std::string& host,
+ uint16_t port,
+ const OpenSslSocketCallback& callback) {
// Connect to SSL port instead of upgrading to TLS.
std::unique_ptr<SSLStream> tls_stream{new SSLStream{task_runner_}};
if (tls_stream->Init(host, port)) {
task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(success_callback, base::Passed(&tls_stream)), {});
+ FROM_HERE, base::Bind(callback, base::Passed(&tls_stream), nullptr),
+ {});
} else {
ErrorPtr error;
Error::AddTo(&error, FROM_HERE, "tls", "tls_init_failed",
"Failed to initialize TLS stream.");
task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(error_callback, base::Passed(&error)), {});
+ FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
}
}
diff --git a/libweave/examples/ubuntu/network_manager.h b/libweave/examples/ubuntu/network_manager.h
index b8e589d..164d419 100644
--- a/libweave/examples/ubuntu/network_manager.h
+++ b/libweave/examples/ubuntu/network_manager.h
@@ -35,14 +35,12 @@
State GetConnectionState() const override;
void OpenSslSocket(const std::string& host,
uint16_t port,
- const OpenSslSocketSuccessCallback& success_callback,
- const ErrorCallback& error_callback) override;
+ const OpenSslSocketCallback& callback) override;
// Wifi implementation.
void Connect(const std::string& ssid,
const std::string& passphrase,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) override;
+ const DoneCallback& callback) override;
void StartAccessPoint(const std::string& ssid) override;
void StopAccessPoint() override;
@@ -53,8 +51,7 @@
const std::string& passphrase,
int pid,
base::Time until,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback);
+ const DoneCallback& callback);
void UpdateNetworkState();
bool force_bootstrapping_{false};
diff --git a/libweave/examples/ubuntu/ssl_stream.cc b/libweave/examples/ubuntu/ssl_stream.cc
index e146d58..8b17358 100644
--- a/libweave/examples/ubuntu/ssl_stream.cc
+++ b/libweave/examples/ubuntu/ssl_stream.cc
@@ -23,14 +23,13 @@
void SSLStream::Read(void* buffer,
size_t size_to_read,
- const ReadSuccessCallback& success_callback,
- const ErrorCallback& error_callback) {
+ const ReadCallback& callback) {
int res = SSL_read(ssl_.get(), buffer, size_to_read);
if (res > 0) {
task_runner_->PostDelayedTask(
FROM_HERE,
base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(),
- base::Bind(success_callback, res)),
+ base::Bind(callback, res, nullptr)),
{});
return;
}
@@ -39,9 +38,8 @@
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
task_runner_->PostDelayedTask(
- FROM_HERE,
- base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(), buffer,
- size_to_read, success_callback, error_callback),
+ FROM_HERE, base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(),
+ buffer, size_to_read, callback),
base::TimeDelta::FromSeconds(1));
return;
}
@@ -52,15 +50,14 @@
task_runner_->PostDelayedTask(
FROM_HERE,
base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(),
- base::Bind(error_callback, base::Passed(&weave_error))),
+ base::Bind(callback, 0, base::Passed(&weave_error))),
{});
return;
}
void SSLStream::Write(const void* buffer,
size_t size_to_write,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) {
+ const WriteCallback& callback) {
int res = SSL_write(ssl_.get(), buffer, size_to_write);
if (res > 0) {
buffer = static_cast<const char*>(buffer) + res;
@@ -69,15 +66,14 @@
task_runner_->PostDelayedTask(
FROM_HERE,
base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(),
- success_callback),
+ base::Bind(callback, nullptr)),
{});
return;
}
task_runner_->PostDelayedTask(
- FROM_HERE,
- base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer,
- size_to_write, success_callback, error_callback),
+ FROM_HERE, base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(),
+ buffer, size_to_write, callback),
base::TimeDelta::FromSeconds(1));
return;
@@ -87,9 +83,8 @@
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
task_runner_->PostDelayedTask(
- FROM_HERE,
- base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer,
- size_to_write, success_callback, error_callback),
+ FROM_HERE, base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(),
+ buffer, size_to_write, callback),
base::TimeDelta::FromSeconds(1));
return;
}
@@ -100,7 +95,7 @@
task_runner_->PostDelayedTask(
FROM_HERE,
base::Bind(&SSLStream::RunDelayedTask, weak_ptr_factory_.GetWeakPtr(),
- base::Bind(error_callback, base::Passed(&weave_error))),
+ base::Bind(callback, base::Passed(&weave_error))),
{});
return;
}
diff --git a/libweave/examples/ubuntu/ssl_stream.h b/libweave/examples/ubuntu/ssl_stream.h
index ac0d76a..859cbf9 100644
--- a/libweave/examples/ubuntu/ssl_stream.h
+++ b/libweave/examples/ubuntu/ssl_stream.h
@@ -26,13 +26,11 @@
void Read(void* buffer,
size_t size_to_read,
- const ReadSuccessCallback& success_callback,
- const ErrorCallback& error_callback) override;
+ const ReadCallback& callback) override;
void Write(const void* buffer,
size_t size_to_write,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) override;
+ const WriteCallback& callback) override;
void CancelPendingOperations() override;
diff --git a/libweave/include/weave/device.h b/libweave/include/weave/device.h
index 7fbe248..09dab91 100644
--- a/libweave/include/weave/device.h
+++ b/libweave/include/weave/device.h
@@ -121,8 +121,7 @@
// Registers the device.
// This is testing method and should not be used by applications.
virtual void Register(const std::string& ticket_id,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) = 0;
+ const DoneCallback& callback) = 0;
// Handler should display pin code to the user.
using PairingBeginCallback =
diff --git a/libweave/include/weave/error.h b/libweave/include/weave/error.h
index 83b33b4..0869b71 100644
--- a/libweave/include/weave/error.h
+++ b/libweave/include/weave/error.h
@@ -126,8 +126,13 @@
DISALLOW_COPY_AND_ASSIGN(Error);
};
-using SuccessCallback = base::Closure;
-using ErrorCallback = base::Callback<void(ErrorPtr error)>;
+// Default callback type for async operations.
+// Function having this callback as argument should call the callback exactly
+// one time.
+// Successfully completed operation should run callback with |error| set to
+// null. Failed operation should run callback with |error| containing error
+// details.
+using DoneCallback = base::Callback<void(ErrorPtr error)>;
} // namespace weave
diff --git a/libweave/include/weave/provider/http_client.h b/libweave/include/weave/provider/http_client.h
index 24b4588..3c442d9 100644
--- a/libweave/include/weave/provider/http_client.h
+++ b/libweave/include/weave/provider/http_client.h
@@ -34,14 +34,14 @@
};
using Headers = std::vector<std::pair<std::string, std::string>>;
- using SuccessCallback = base::Callback<void(const Response&)>;
+ using SendRequestCallback =
+ base::Callback<void(std::unique_ptr<Response> response, ErrorPtr error)>;
virtual void SendRequest(Method method,
const std::string& url,
const Headers& headers,
const std::string& data,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) = 0;
+ const SendRequestCallback& callback) = 0;
protected:
virtual ~HttpClient() = default;
diff --git a/libweave/include/weave/provider/network.h b/libweave/include/weave/provider/network.h
index cb3730e..4c0d5dd 100644
--- a/libweave/include/weave/provider/network.h
+++ b/libweave/include/weave/provider/network.h
@@ -29,8 +29,8 @@
using ConnectionChangedCallback = base::Closure;
// Callback type for OpenSslSocket.
- using OpenSslSocketSuccessCallback =
- base::Callback<void(std::unique_ptr<Stream> stream)>;
+ using OpenSslSocketCallback =
+ base::Callback<void(std::unique_ptr<Stream> stream, ErrorPtr error)>;
// Subscribes to notification about changes in network connectivity. Changes
// may include but not limited: interface up or down, new IP was assigned,
@@ -42,11 +42,9 @@
virtual State GetConnectionState() const = 0;
// Opens bidirectional sockets and returns attached stream.
- virtual void OpenSslSocket(
- const std::string& host,
- uint16_t port,
- const OpenSslSocketSuccessCallback& success_callback,
- const ErrorCallback& error_callback) = 0;
+ virtual void OpenSslSocket(const std::string& host,
+ uint16_t port,
+ const OpenSslSocketCallback& callback) = 0;
protected:
virtual ~Network() = default;
diff --git a/libweave/include/weave/provider/test/mock_http_client.h b/libweave/include/weave/provider/test/mock_http_client.h
index 72cb9b8..85ac154 100644
--- a/libweave/include/weave/provider/test/mock_http_client.h
+++ b/libweave/include/weave/provider/test/mock_http_client.h
@@ -27,13 +27,12 @@
public:
~MockHttpClient() override = default;
- MOCK_METHOD6(SendRequest,
+ MOCK_METHOD5(SendRequest,
void(Method,
const std::string&,
const Headers&,
const std::string&,
- const SuccessCallback&,
- const ErrorCallback&));
+ const SendRequestCallback&));
};
} // namespace test
diff --git a/libweave/include/weave/provider/test/mock_network.h b/libweave/include/weave/provider/test/mock_network.h
index e38dde2..b2811dc 100644
--- a/libweave/include/weave/provider/test/mock_network.h
+++ b/libweave/include/weave/provider/test/mock_network.h
@@ -20,11 +20,10 @@
MOCK_METHOD1(AddConnectionChangedCallback,
void(const ConnectionChangedCallback&));
MOCK_CONST_METHOD0(GetConnectionState, State());
- MOCK_METHOD4(OpenSslSocket,
+ MOCK_METHOD3(OpenSslSocket,
void(const std::string&,
uint16_t,
- const OpenSslSocketSuccessCallback&,
- const ErrorCallback&));
+ const OpenSslSocketCallback&));
};
} // namespace test
diff --git a/libweave/include/weave/provider/test/mock_wifi.h b/libweave/include/weave/provider/test/mock_wifi.h
index 6c53d9a..5798872 100644
--- a/libweave/include/weave/provider/test/mock_wifi.h
+++ b/libweave/include/weave/provider/test/mock_wifi.h
@@ -17,11 +17,10 @@
class MockWifi : public Wifi {
public:
- MOCK_METHOD4(Connect,
+ MOCK_METHOD3(Connect,
void(const std::string&,
const std::string&,
- const base::Closure&,
- const ErrorCallback&));
+ const DoneCallback&));
MOCK_METHOD1(StartAccessPoint, void(const std::string&));
MOCK_METHOD0(StopAccessPoint, void());
};
diff --git a/libweave/include/weave/provider/wifi.h b/libweave/include/weave/provider/wifi.h
index 51f370c..aaf4609 100644
--- a/libweave/include/weave/provider/wifi.h
+++ b/libweave/include/weave/provider/wifi.h
@@ -20,8 +20,7 @@
// should post either of callbacks.
virtual void Connect(const std::string& ssid,
const std::string& passphrase,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) = 0;
+ const DoneCallback& callback) = 0;
// Starts WiFi access point for wifi setup.
virtual void StartAccessPoint(const std::string& ssid) = 0;
diff --git a/libweave/include/weave/stream.h b/libweave/include/weave/stream.h
index 9d4d6fd..ea8af23 100644
--- a/libweave/include/weave/stream.h
+++ b/libweave/include/weave/stream.h
@@ -18,15 +18,14 @@
virtual ~InputStream() = default;
// Callback type for Read.
- using ReadSuccessCallback = base::Callback<void(size_t size)>;
+ using ReadCallback = base::Callback<void(size_t size, ErrorPtr error)>;
- // Implementation should return immediately and post either success_callback
- // or error_callback. Caller guarantees that buffet is alive until either of
- // callback is called.
+ // Implementation should return immediately and post callback after
+ // completing operation. Caller guarantees that buffet is alive until callback
+ // is called.
virtual void Read(void* buffer,
size_t size_to_read,
- const ReadSuccessCallback& success_callback,
- const ErrorCallback& error_callback) = 0;
+ const ReadCallback& callback) = 0;
};
// Interface for async input streaming.
@@ -34,14 +33,15 @@
public:
virtual ~OutputStream() = default;
- // Implementation should return immediately and post either success_callback
- // or error_callback. Caller guarantees that buffet is alive until either of
- // callback is called.
+ using WriteCallback = base::Callback<void(ErrorPtr error)>;
+
+ // Implementation should return immediately and post callback after
+ // completing operation. Caller guarantees that buffet is alive until either
+ // of callback is called.
// Success callback must be called only after all data is written.
virtual void Write(const void* buffer,
size_t size_to_write,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) = 0;
+ const WriteCallback& callback) = 0;
};
// Interface for async bi-directional streaming.
diff --git a/libweave/include/weave/test/fake_stream.h b/libweave/include/weave/test/fake_stream.h
index 8abb491..e5e0b57 100644
--- a/libweave/include/weave/test/fake_stream.h
+++ b/libweave/include/weave/test/fake_stream.h
@@ -31,12 +31,10 @@
void CancelPendingOperations() override;
void Read(void* buffer,
size_t size_to_read,
- const ReadSuccessCallback& success_callback,
- const ErrorCallback& error_callback) override;
+ const ReadCallback& callback) override;
void Write(const void* buffer,
size_t size_to_write,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) override;
+ const WriteCallback& callback) override;
private:
provider::TaskRunner* task_runner_{nullptr};
diff --git a/libweave/include/weave/test/mock_device.h b/libweave/include/weave/test/mock_device.h
index 0d8a6da..c0a75bc 100644
--- a/libweave/include/weave/test/mock_device.h
+++ b/libweave/include/weave/test/mock_device.h
@@ -44,10 +44,9 @@
MOCK_CONST_METHOD0(GetGcdState, GcdState());
MOCK_METHOD1(AddGcdStateChangedCallback,
void(const GcdStateChangedCallback& callback));
- MOCK_METHOD3(Register,
+ MOCK_METHOD2(Register,
void(const std::string& ticket_id,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback));
+ const DoneCallback& callback));
MOCK_METHOD2(AddPairingChangedCallbacks,
void(const PairingBeginCallback& begin_callback,
const PairingEndCallback& end_callback));
diff --git a/libweave/src/commands/cloud_command_proxy.cc b/libweave/src/commands/cloud_command_proxy.cc
index 3826378..41dc8ee 100644
--- a/libweave/src/commands/cloud_command_proxy.cc
+++ b/libweave/src/commands/cloud_command_proxy.cc
@@ -139,10 +139,8 @@
command_update_in_progress_ = true;
cloud_command_updater_->UpdateCommand(
command_instance_->GetID(), *update_queue_.front().second,
- base::Bind(&CloudCommandProxy::OnUpdateCommandFinished,
- weak_ptr_factory_.GetWeakPtr(), true),
- base::Bind(&CloudCommandProxy::OnUpdateCommandFinished,
- weak_ptr_factory_.GetWeakPtr(), false));
+ base::Bind(&CloudCommandProxy::OnUpdateCommandDone,
+ weak_ptr_factory_.GetWeakPtr()));
}
void CloudCommandProxy::ResendCommandUpdate() {
@@ -150,10 +148,10 @@
SendCommandUpdate();
}
-void CloudCommandProxy::OnUpdateCommandFinished(bool success) {
+void CloudCommandProxy::OnUpdateCommandDone(ErrorPtr error) {
command_update_in_progress_ = false;
- cloud_backoff_entry_->InformOfRequest(success);
- if (success) {
+ cloud_backoff_entry_->InformOfRequest(!error);
+ if (!error) {
// Remove the succeeded update from the queue.
update_queue_.pop_front();
}
diff --git a/libweave/src/commands/cloud_command_proxy.h b/libweave/src/commands/cloud_command_proxy.h
index 86dc5d3..b2ef11a 100644
--- a/libweave/src/commands/cloud_command_proxy.h
+++ b/libweave/src/commands/cloud_command_proxy.h
@@ -63,9 +63,7 @@
void ResendCommandUpdate();
// Callback invoked by the asynchronous PATCH request to the server.
- // Called both in a case of successfully updating server command resource
- // and in case of an error, indicated by the |success| parameter.
- void OnUpdateCommandFinished(bool success);
+ void OnUpdateCommandDone(ErrorPtr error);
// Callback invoked by the device state change queue to notify of the
// successful device state update. |update_id| is the ID of the state that
diff --git a/libweave/src/commands/cloud_command_proxy_unittest.cc b/libweave/src/commands/cloud_command_proxy_unittest.cc
index 3c7692f..d9518e4 100644
--- a/libweave/src/commands/cloud_command_proxy_unittest.cc
+++ b/libweave/src/commands/cloud_command_proxy_unittest.cc
@@ -38,11 +38,10 @@
class MockCloudCommandUpdateInterface : public CloudCommandUpdateInterface {
public:
- MOCK_METHOD4(UpdateCommand,
+ MOCK_METHOD3(UpdateCommand,
void(const std::string&,
const base::DictionaryValue&,
- const base::Closure&,
- const base::Closure&));
+ const DoneCallback&));
};
// Test back-off entry that uses the test clock.
@@ -147,7 +146,7 @@
TEST_F(CloudCommandProxyTest, ImmediateUpdate) {
const char expected[] = "{'state':'done'}";
- EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _, _));
+ EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _));
command_instance_->Complete({}, nullptr);
task_runner_.RunOnce();
}
@@ -161,7 +160,7 @@
callbacks_.Notify(19);
// Now we should get the update...
const char expected[] = "{'state':'done'}";
- EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _, _));
+ EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _));
callbacks_.Notify(20);
}
@@ -170,14 +169,14 @@
// state=inProgress
// progress={...}
// The first state update is sent immediately, the second should be delayed.
- base::Closure on_success;
+ DoneCallback callback;
EXPECT_CALL(
cloud_updater_,
UpdateCommand(
kCmdID,
- MatchJson("{'state':'inProgress', 'progress':{'status':'ready'}}"), _,
+ MatchJson("{'state':'inProgress', 'progress':{'status':'ready'}}"),
_))
- .WillOnce(SaveArg<2>(&on_success));
+ .WillOnce(SaveArg<2>(&callback));
EXPECT_TRUE(command_instance_->SetProgress(
*CreateDictionaryValue("{'status': 'ready'}"), nullptr));
@@ -200,34 +199,35 @@
'progress': {'status':'ready'},
'state':'inProgress'
})";
- EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _, _));
+ EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _));
callbacks_.Notify(20);
}
TEST_F(CloudCommandProxyTest, RetryFailed) {
- base::Closure on_success;
- base::Closure on_error;
+ DoneCallback callback;
const char expect[] =
"{'state':'inProgress', 'progress': {'status': 'ready'}}";
- EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect), _, _))
+ EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect), _))
.Times(3)
- .WillRepeatedly(DoAll(SaveArg<2>(&on_success), SaveArg<3>(&on_error)));
+ .WillRepeatedly(SaveArg<2>(&callback));
auto started = task_runner_.GetClock()->Now();
EXPECT_TRUE(command_instance_->SetProgress(
*CreateDictionaryValue("{'status': 'ready'}"), nullptr));
task_runner_.Run();
- on_error.Run();
+ ErrorPtr error;
+ Error::AddTo(&error, FROM_HERE, "TEST", "TEST", "TEST");
+ callback.Run(error->Clone());
task_runner_.Run();
EXPECT_GE(task_runner_.GetClock()->Now() - started,
base::TimeDelta::FromSecondsD(0.9));
- on_error.Run();
+ callback.Run(error->Clone());
task_runner_.Run();
EXPECT_GE(task_runner_.GetClock()->Now() - started,
base::TimeDelta::FromSecondsD(2.9));
- on_success.Run();
+ callback.Run(nullptr);
task_runner_.Run();
EXPECT_GE(task_runner_.GetClock()->Now() - started,
base::TimeDelta::FromSecondsD(2.9));
@@ -244,20 +244,20 @@
command_instance_->Complete({}, nullptr);
// Device state #20 updated.
- base::Closure on_success;
+ DoneCallback callback;
const char expect1[] = R"({
'progress': {'status':'ready'},
'state':'inProgress'
})";
- EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect1), _, _))
- .WillOnce(SaveArg<2>(&on_success));
+ EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect1), _))
+ .WillOnce(SaveArg<2>(&callback));
callbacks_.Notify(20);
- on_success.Run();
+ callback.Run(nullptr);
// Device state #21 updated.
const char expect2[] = "{'progress': {'status':'busy'}}";
- EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect2), _, _))
- .WillOnce(SaveArg<2>(&on_success));
+ EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect2), _))
+ .WillOnce(SaveArg<2>(&callback));
callbacks_.Notify(21);
// Device state #22 updated. Nothing happens here since the previous command
@@ -267,9 +267,9 @@
// Now the command update is complete, send out the patch that happened after
// the state #22 was updated.
const char expect3[] = "{'state': 'done'}";
- EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect3), _, _))
- .WillOnce(SaveArg<2>(&on_success));
- on_success.Run();
+ EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect3), _))
+ .WillOnce(SaveArg<2>(&callback));
+ callback.Run(nullptr);
}
TEST_F(CloudCommandProxyTest, CombineSomeStates) {
@@ -283,22 +283,22 @@
command_instance_->Complete({}, nullptr);
// Device state 20-21 updated.
- base::Closure on_success;
+ DoneCallback callback;
const char expect1[] = R"({
'progress': {'status':'busy'},
'state':'inProgress'
})";
- EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect1), _, _))
- .WillOnce(SaveArg<2>(&on_success));
+ EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect1), _))
+ .WillOnce(SaveArg<2>(&callback));
callbacks_.Notify(21);
- on_success.Run();
+ callback.Run(nullptr);
// Device state #22 updated.
const char expect2[] = "{'state': 'done'}";
- EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect2), _, _))
- .WillOnce(SaveArg<2>(&on_success));
+ EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expect2), _))
+ .WillOnce(SaveArg<2>(&callback));
callbacks_.Notify(22);
- on_success.Run();
+ callback.Run(nullptr);
}
TEST_F(CloudCommandProxyTest, CombineAllStates) {
@@ -316,7 +316,7 @@
'progress': {'status':'busy'},
'state':'done'
})";
- EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _, _));
+ EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _));
callbacks_.Notify(30);
}
@@ -336,7 +336,7 @@
'results': {'sum':30},
'state':'done'
})";
- EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _, _));
+ EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _));
callbacks_.Notify(30);
}
@@ -352,7 +352,7 @@
// As soon as we change the command, the update to the server should be sent.
const char expected[] = "{'state':'done'}";
- EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _, _));
+ EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _));
command_instance_->Complete({}, nullptr);
task_runner_.RunOnce();
}
@@ -370,7 +370,7 @@
// Only when the state #20 is published we should update the command
const char expected[] = "{'state':'done'}";
- EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _, _));
+ EXPECT_CALL(cloud_updater_, UpdateCommand(kCmdID, MatchJson(expected), _));
callbacks_.Notify(20);
}
diff --git a/libweave/src/commands/cloud_command_update_interface.h b/libweave/src/commands/cloud_command_update_interface.h
index 5bd5c11..71c2f18 100644
--- a/libweave/src/commands/cloud_command_update_interface.h
+++ b/libweave/src/commands/cloud_command_update_interface.h
@@ -18,8 +18,7 @@
public:
virtual void UpdateCommand(const std::string& command_id,
const base::DictionaryValue& command_patch,
- const base::Closure& on_success,
- const base::Closure& on_error) = 0;
+ const DoneCallback& callback) = 0;
protected:
virtual ~CloudCommandUpdateInterface() = default;
diff --git a/libweave/src/device_manager.cc b/libweave/src/device_manager.cc
index 9105352..54570f6 100644
--- a/libweave/src/device_manager.cc
+++ b/libweave/src/device_manager.cc
@@ -152,9 +152,8 @@
}
void DeviceManager::Register(const std::string& ticket_id,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) {
- device_info_->RegisterDevice(ticket_id, success_callback, error_callback);
+ const DoneCallback& callback) {
+ device_info_->RegisterDevice(ticket_id, callback);
}
void DeviceManager::AddPairingChangedCallbacks(
diff --git a/libweave/src/device_manager.h b/libweave/src/device_manager.h
index 16e0601..7bded15 100644
--- a/libweave/src/device_manager.h
+++ b/libweave/src/device_manager.h
@@ -59,8 +59,7 @@
ErrorPtr* error) override;
std::unique_ptr<base::DictionaryValue> GetState() const override;
void Register(const std::string& ticket_id,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) override;
+ const DoneCallback& callback) override;
GcdState GetGcdState() const override;
void AddGcdStateChangedCallback(
const GcdStateChangedCallback& callback) override;
diff --git a/libweave/src/device_registration_info.cc b/libweave/src/device_registration_info.cc
index 6481481..e35b768 100644
--- a/libweave/src/device_registration_info.cc
+++ b/libweave/src/device_registration_info.cc
@@ -98,17 +98,18 @@
return AppendQueryParams(result, params);
}
-void IgnoreCloudError(ErrorPtr) {}
-
void IgnoreCloudErrorWithCallback(const base::Closure& cb, ErrorPtr) {
cb.Run();
}
-void IgnoreCloudResult(const base::DictionaryValue&) {}
+void IgnoreCloudError(ErrorPtr) {}
-void IgnoreCloudResultWithCallback(const base::Closure& cb,
- const base::DictionaryValue&) {
- cb.Run();
+void IgnoreCloudResult(const base::DictionaryValue&, ErrorPtr error) {}
+
+void IgnoreCloudResultWithCallback(const DoneCallback& cb,
+ const base::DictionaryValue&,
+ ErrorPtr error) {
+ cb.Run(std::move(error));
}
class RequestSender final {
@@ -118,31 +119,28 @@
HttpClient* transport)
: method_{method}, url_{url}, transport_{transport} {}
- void Send(const HttpClient::SuccessCallback& success_callback,
- const ErrorCallback& error_callback) {
+ void Send(const HttpClient::SendRequestCallback& callback) {
static int debug_id = 0;
++debug_id;
VLOG(1) << "Sending request. id:" << debug_id
<< " method:" << EnumToString(method_) << " url:" << url_;
VLOG(2) << "Request data: " << data_;
- auto on_success = [](int debug_id,
- const HttpClient::SuccessCallback& success_callback,
- const HttpClient::Response& response) {
- VLOG(1) << "Request succeeded. id:" << debug_id << " status:" <<
- response.GetStatusCode();
- VLOG(2) << "Response data: " << response.GetData();
- success_callback.Run(response);
- };
- auto on_error = [](int debug_id, const ErrorCallback& error_callback,
- ErrorPtr error) {
- VLOG(1) << "Request failed, id=" << debug_id
- << ", reason: " << error->GetCode()
- << ", message: " << error->GetMessage();
- error_callback.Run(std::move(error));
+ auto on_done = [](
+ int debug_id, const HttpClient::SendRequestCallback& callback,
+ std::unique_ptr<HttpClient::Response> response, ErrorPtr error) {
+ if (error) {
+ VLOG(1) << "Request failed, id=" << debug_id
+ << ", reason: " << error->GetCode()
+ << ", message: " << error->GetMessage();
+ return callback.Run({}, std::move(error));
+ }
+ VLOG(1) << "Request succeeded. id:" << debug_id
+ << " status:" << response->GetStatusCode();
+ VLOG(2) << "Response data: " << response->GetData();
+ callback.Run(std::move(response), nullptr);
};
transport_->SendRequest(method_, url_, GetFullHeaders(), data_,
- base::Bind(on_success, debug_id, success_callback),
- base::Bind(on_error, debug_id, error_callback));
+ base::Bind(on_done, debug_id, callback));
}
void SetAccessToken(const std::string& access_token) {
@@ -304,7 +302,8 @@
return; // Assume we're in test
task_runner_->PostDelayedTask(
FROM_HERE,
- base::Bind(&DeviceRegistrationInfo::ConnectToCloud, AsWeakPtr()), delay);
+ base::Bind(&DeviceRegistrationInfo::ConnectToCloud, AsWeakPtr(), nullptr),
+ delay);
}
bool DeviceRegistrationInfo::HaveRegistrationCredentials() const {
@@ -350,16 +349,12 @@
return resp;
}
-void DeviceRegistrationInfo::RefreshAccessToken(
- const base::Closure& success_callback,
- const ErrorCallback& error_callback) {
+void DeviceRegistrationInfo::RefreshAccessToken(const DoneCallback& callback) {
LOG(INFO) << "Refreshing access token.";
ErrorPtr error;
- if (!VerifyRegistrationCredentials(&error)) {
- error_callback.Run(std::move(error));
- return;
- }
+ if (!VerifyRegistrationCredentials(&error))
+ return callback.Run(std::move(error));
if (oauth2_backoff_entry_->ShouldRejectRequest()) {
VLOG(1) << "RefreshToken request delayed for "
@@ -367,19 +362,11 @@
<< " due to backoff policy";
task_runner_->PostDelayedTask(
FROM_HERE, base::Bind(&DeviceRegistrationInfo::RefreshAccessToken,
- AsWeakPtr(), success_callback, error_callback),
+ AsWeakPtr(), callback),
oauth2_backoff_entry_->GetTimeUntilRelease());
return;
}
- // Make a shared pointer to |success_callback| and |error_callback| since we
- // are going to share these callbacks with both success and error callbacks
- // for PostFormData() and if the callbacks have any move-only types,
- // one of the copies will be bad.
- auto shared_success_callback =
- std::make_shared<base::Closure>(success_callback);
- auto shared_error_callback = std::make_shared<ErrorCallback>(error_callback);
-
RequestSender sender{HttpClient::Method::kPost, GetOAuthURL("token"),
http_client_};
sender.SetFormData({
@@ -388,25 +375,25 @@
{"client_secret", GetSettings().client_secret},
{"grant_type", "refresh_token"},
});
- sender.Send(base::Bind(&DeviceRegistrationInfo::OnRefreshAccessTokenSuccess,
- weak_factory_.GetWeakPtr(), shared_success_callback,
- shared_error_callback),
- base::Bind(&DeviceRegistrationInfo::OnRefreshAccessTokenError,
- weak_factory_.GetWeakPtr(), shared_success_callback,
- shared_error_callback));
+ sender.Send(base::Bind(&DeviceRegistrationInfo::OnRefreshAccessTokenDone,
+ weak_factory_.GetWeakPtr(), callback));
VLOG(1) << "Refresh access token request dispatched";
}
-void DeviceRegistrationInfo::OnRefreshAccessTokenSuccess(
- const std::shared_ptr<base::Closure>& success_callback,
- const std::shared_ptr<ErrorCallback>& error_callback,
- const HttpClient::Response& response) {
+void DeviceRegistrationInfo::OnRefreshAccessTokenDone(
+ const DoneCallback& callback,
+ std::unique_ptr<HttpClient::Response> response,
+ ErrorPtr error) {
+ if (error) {
+ VLOG(1) << "Refresh access token failed";
+ oauth2_backoff_entry_->InformOfRequest(false);
+ return RefreshAccessToken(callback);
+ }
VLOG(1) << "Refresh access token request completed";
oauth2_backoff_entry_->InformOfRequest(true);
- ErrorPtr error;
- auto json = ParseOAuthResponse(response, &error);
+ auto json = ParseOAuthResponse(*response, &error);
if (!json)
- return error_callback->Run(std::move(error));
+ return callback.Run(std::move(error));
int expires_in = 0;
if (!json->GetString("access_token", &access_token_) ||
@@ -415,7 +402,7 @@
LOG(ERROR) << "Access token unavailable.";
Error::AddTo(&error, FROM_HERE, kErrorDomainOAuth2,
"unexpected_server_response", "Access token unavailable");
- return error_callback->Run(std::move(error));
+ return callback.Run(std::move(error));
}
access_token_expiration_ =
base::Time::Now() + base::TimeDelta::FromSeconds(expires_in);
@@ -428,16 +415,7 @@
// Now that we have a new access token, retry the connection.
StartNotificationChannel();
}
- success_callback->Run();
-}
-
-void DeviceRegistrationInfo::OnRefreshAccessTokenError(
- const std::shared_ptr<base::Closure>& success_callback,
- const std::shared_ptr<ErrorCallback>& error_callback,
- ErrorPtr error) {
- VLOG(1) << "Refresh access token failed";
- oauth2_backoff_entry_->InformOfRequest(false);
- RefreshAccessToken(*success_callback, *error_callback);
+ callback.Run(nullptr);
}
void DeviceRegistrationInfo::StartNotificationChannel() {
@@ -522,45 +500,27 @@
}
void DeviceRegistrationInfo::GetDeviceInfo(
- const CloudRequestCallback& success_callback,
- const ErrorCallback& error_callback) {
+ const CloudRequestDoneCallback& callback) {
ErrorPtr error;
if (!VerifyRegistrationCredentials(&error)) {
- if (!error_callback.is_null())
- error_callback.Run(std::move(error));
- return;
+ return callback.Run({}, std::move(error));
}
- DoCloudRequest(HttpClient::Method::kGet, GetDeviceURL(), nullptr,
- success_callback, error_callback);
+ DoCloudRequest(HttpClient::Method::kGet, GetDeviceURL(), nullptr, callback);
}
-struct DeviceRegistrationInfo::RegisterCallbacks {
- RegisterCallbacks(const SuccessCallback& success, const ErrorCallback& error)
- : success_callback{success}, error_callback{error} {}
- SuccessCallback success_callback;
- ErrorCallback error_callback;
-};
-
-void DeviceRegistrationInfo::RegisterDeviceError(
- const std::shared_ptr<RegisterCallbacks>& callbacks,
- ErrorPtr error) {
- task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(callbacks->error_callback, base::Passed(&error)),
- {});
+void DeviceRegistrationInfo::RegisterDeviceError(const DoneCallback& callback,
+ ErrorPtr error) {
+ task_runner_->PostDelayedTask(FROM_HERE,
+ base::Bind(callback, base::Passed(&error)), {});
}
-void DeviceRegistrationInfo::RegisterDevice(
- const std::string& ticket_id,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) {
- auto callbacks =
- std::make_shared<RegisterCallbacks>(success_callback, error_callback);
-
+void DeviceRegistrationInfo::RegisterDevice(const std::string& ticket_id,
+ const DoneCallback& callback) {
ErrorPtr error;
std::unique_ptr<base::DictionaryValue> device_draft =
BuildDeviceResource(&error);
if (!device_draft)
- return RegisterDeviceError(callbacks, std::move(error));
+ return RegisterDeviceError(callback, std::move(error));
base::DictionaryValue req_json;
req_json.SetString("id", ticket_id);
@@ -573,23 +533,23 @@
RequestSender sender{HttpClient::Method::kPatch, url, http_client_};
sender.SetJsonData(req_json);
sender.Send(base::Bind(&DeviceRegistrationInfo::RegisterDeviceOnTicketSent,
- weak_factory_.GetWeakPtr(), ticket_id, callbacks),
- base::Bind(&DeviceRegistrationInfo::RegisterDeviceError,
- weak_factory_.GetWeakPtr(), callbacks));
+ weak_factory_.GetWeakPtr(), ticket_id, callback));
}
void DeviceRegistrationInfo::RegisterDeviceOnTicketSent(
const std::string& ticket_id,
- const std::shared_ptr<RegisterCallbacks>& callbacks,
- const provider::HttpClient::Response& response) {
- ErrorPtr error;
- auto json_resp = ParseJsonResponse(response, &error);
+ const DoneCallback& callback,
+ std::unique_ptr<provider::HttpClient::Response> response,
+ ErrorPtr error) {
+ if (error)
+ return RegisterDeviceError(callback, std::move(error));
+ auto json_resp = ParseJsonResponse(*response, &error);
if (!json_resp)
- return RegisterDeviceError(callbacks, std::move(error));
+ return RegisterDeviceError(callback, std::move(error));
- if (!IsSuccessful(response)) {
+ if (!IsSuccessful(*response)) {
ParseGCDError(json_resp.get(), &error);
- return RegisterDeviceError(callbacks, std::move(error));
+ return RegisterDeviceError(callback, std::move(error));
}
std::string url =
@@ -597,21 +557,21 @@
{{"key", GetSettings().api_key}});
RequestSender{HttpClient::Method::kPost, url, http_client_}.Send(
base::Bind(&DeviceRegistrationInfo::RegisterDeviceOnTicketFinalized,
- weak_factory_.GetWeakPtr(), callbacks),
- base::Bind(&DeviceRegistrationInfo::RegisterDeviceError,
- weak_factory_.GetWeakPtr(), callbacks));
+ weak_factory_.GetWeakPtr(), callback));
}
void DeviceRegistrationInfo::RegisterDeviceOnTicketFinalized(
- const std::shared_ptr<RegisterCallbacks>& callbacks,
- const provider::HttpClient::Response& response) {
- ErrorPtr error;
- auto json_resp = ParseJsonResponse(response, &error);
+ const DoneCallback& callback,
+ std::unique_ptr<provider::HttpClient::Response> response,
+ ErrorPtr error) {
+ if (error)
+ return RegisterDeviceError(callback, std::move(error));
+ auto json_resp = ParseJsonResponse(*response, &error);
if (!json_resp)
- return RegisterDeviceError(callbacks, std::move(error));
- if (!IsSuccessful(response)) {
+ return RegisterDeviceError(callback, std::move(error));
+ if (!IsSuccessful(*response)) {
ParseGCDError(json_resp.get(), &error);
- return RegisterDeviceError(callbacks, std::move(error));
+ return RegisterDeviceError(callback, std::move(error));
}
std::string auth_code;
@@ -624,7 +584,7 @@
!device_draft_response->GetString("id", &cloud_id)) {
Error::AddTo(&error, FROM_HERE, kErrorDomainGCD, "unexpected_response",
"Device account missing in response");
- return RegisterDeviceError(callbacks, std::move(error));
+ return RegisterDeviceError(callback, std::move(error));
}
UpdateDeviceInfoTimestamp(*device_draft_response);
@@ -641,18 +601,18 @@
{"grant_type", "authorization_code"}});
sender2.Send(base::Bind(&DeviceRegistrationInfo::RegisterDeviceOnAuthCodeSent,
weak_factory_.GetWeakPtr(), cloud_id, robot_account,
- callbacks),
- base::Bind(&DeviceRegistrationInfo::RegisterDeviceError,
- weak_factory_.GetWeakPtr(), callbacks));
+ callback));
}
void DeviceRegistrationInfo::RegisterDeviceOnAuthCodeSent(
const std::string& cloud_id,
const std::string& robot_account,
- const std::shared_ptr<RegisterCallbacks>& callbacks,
- const provider::HttpClient::Response& response) {
- ErrorPtr error;
- auto json_resp = ParseOAuthResponse(response, &error);
+ const DoneCallback& callback,
+ std::unique_ptr<provider::HttpClient::Response> response,
+ ErrorPtr error) {
+ if (error)
+ return RegisterDeviceError(callback, std::move(error));
+ auto json_resp = ParseOAuthResponse(*response, &error);
int expires_in = 0;
std::string refresh_token;
if (!json_resp || !json_resp->GetString("access_token", &access_token_) ||
@@ -661,7 +621,7 @@
access_token_.empty() || refresh_token.empty() || expires_in <= 0) {
Error::AddTo(&error, FROM_HERE, kErrorDomainGCD, "unexpected_response",
"Device access_token missing in response");
- return RegisterDeviceError(callbacks, std::move(error));
+ return RegisterDeviceError(callback, std::move(error));
}
access_token_expiration_ =
@@ -673,7 +633,7 @@
change.set_refresh_token(refresh_token);
change.Commit();
- task_runner_->PostDelayedTask(FROM_HERE, callbacks->success_callback, {});
+ task_runner_->PostDelayedTask(FROM_HERE, base::Bind(callback, nullptr), {});
StartNotificationChannel();
@@ -686,10 +646,9 @@
HttpClient::Method method,
const std::string& url,
const base::DictionaryValue* body,
- const CloudRequestCallback& success_callback,
- const ErrorCallback& error_callback) {
+ const CloudRequestDoneCallback& callback) {
// We make CloudRequestData shared here because we want to make sure
- // there is only one instance of success_callback and error_calback since
+ // there is only one instance of callback and error_calback since
// those may have move-only types and making a copy of the callback with
// move-only types curried-in will invalidate the source callback.
auto data = std::make_shared<CloudRequestData>();
@@ -697,8 +656,7 @@
data->url = url;
if (body)
base::JSONWriter::Write(*body, &data->body);
- data->success_callback = success_callback;
- data->error_callback = error_callback;
+ data->callback = callback;
SendCloudRequest(data);
}
@@ -710,7 +668,7 @@
ErrorPtr error;
if (!VerifyRegistrationCredentials(&error)) {
- return data->error_callback.Run(std::move(error));
+ return data->callback.Run({}, std::move(error));
}
if (cloud_backoff_entry_->ShouldRejectRequest()) {
@@ -726,22 +684,21 @@
RequestSender sender{data->method, data->url, http_client_};
sender.SetData(data->body, http::kJsonUtf8);
sender.SetAccessToken(access_token_);
- sender.Send(base::Bind(&DeviceRegistrationInfo::OnCloudRequestSuccess,
- AsWeakPtr(), data),
- base::Bind(&DeviceRegistrationInfo::OnCloudRequestError,
+ sender.Send(base::Bind(&DeviceRegistrationInfo::OnCloudRequestDone,
AsWeakPtr(), data));
}
-void DeviceRegistrationInfo::OnCloudRequestSuccess(
+void DeviceRegistrationInfo::OnCloudRequestDone(
const std::shared_ptr<const CloudRequestData>& data,
- const HttpClient::Response& response) {
- int status_code = response.GetStatusCode();
+ std::unique_ptr<provider::HttpClient::Response> response,
+ ErrorPtr error) {
+ if (error)
+ return RetryCloudRequest(data);
+ int status_code = response->GetStatusCode();
if (status_code == http::kDenied) {
cloud_backoff_entry_->InformOfRequest(true);
RefreshAccessToken(
base::Bind(&DeviceRegistrationInfo::OnAccessTokenRefreshed, AsWeakPtr(),
- data),
- base::Bind(&DeviceRegistrationInfo::OnAccessTokenError, AsWeakPtr(),
data));
return;
}
@@ -755,14 +712,13 @@
return;
}
- ErrorPtr error;
- auto json_resp = ParseJsonResponse(response, &error);
+ auto json_resp = ParseJsonResponse(*response, &error);
if (!json_resp) {
cloud_backoff_entry_->InformOfRequest(true);
- return data->error_callback.Run(std::move(error));
+ return data->callback.Run({}, std::move(error));
}
- if (!IsSuccessful(response)) {
+ if (!IsSuccessful(*response)) {
ParseGCDError(json_resp.get(), &error);
if (status_code == http::kForbidden &&
error->HasError(kErrorDomainGCDServer, "rateLimitExceeded")) {
@@ -770,18 +726,12 @@
return RetryCloudRequest(data);
}
cloud_backoff_entry_->InformOfRequest(true);
- return data->error_callback.Run(std::move(error));
+ return data->callback.Run({}, std::move(error));
}
cloud_backoff_entry_->InformOfRequest(true);
SetGcdState(GcdState::kConnected);
- data->success_callback.Run(*json_resp);
-}
-
-void DeviceRegistrationInfo::OnCloudRequestError(
- const std::shared_ptr<const CloudRequestData>& data,
- ErrorPtr error) {
- RetryCloudRequest(data);
+ data->callback.Run(*json_resp, nullptr);
}
void DeviceRegistrationInfo::RetryCloudRequest(
@@ -793,32 +743,34 @@
}
void DeviceRegistrationInfo::OnAccessTokenRefreshed(
- const std::shared_ptr<const CloudRequestData>& data) {
+ const std::shared_ptr<const CloudRequestData>& data,
+ ErrorPtr error) {
+ if (error) {
+ CheckAccessTokenError(error->Clone());
+ return data->callback.Run({}, std::move(error));
+ }
SendCloudRequest(data);
}
-void DeviceRegistrationInfo::OnAccessTokenError(
- const std::shared_ptr<const CloudRequestData>& data,
- ErrorPtr error) {
- CheckAccessTokenError(error->Clone());
- data->error_callback.Run(std::move(error));
-}
-
void DeviceRegistrationInfo::CheckAccessTokenError(ErrorPtr error) {
- if (error->HasError(kErrorDomainOAuth2, "invalid_grant"))
+ if (error && error->HasError(kErrorDomainOAuth2, "invalid_grant"))
MarkDeviceUnregistered();
}
-void DeviceRegistrationInfo::ConnectToCloud() {
+void DeviceRegistrationInfo::ConnectToCloud(ErrorPtr error) {
+ if (error) {
+ if (error->HasError(kErrorDomainOAuth2, "invalid_grant"))
+ MarkDeviceUnregistered();
+ return;
+ }
+
connected_to_cloud_ = false;
if (!VerifyRegistrationCredentials(nullptr))
return;
if (access_token_.empty()) {
RefreshAccessToken(
- base::Bind(&DeviceRegistrationInfo::ConnectToCloud, AsWeakPtr()),
- base::Bind(&DeviceRegistrationInfo::CheckAccessTokenError,
- AsWeakPtr()));
+ base::Bind(&DeviceRegistrationInfo::ConnectToCloud, AsWeakPtr()));
return;
}
@@ -828,16 +780,16 @@
// 3) abort any commands that we've previously marked as "in progress"
// or as being in an error state; publish queued commands
UpdateDeviceResource(
- base::Bind(&DeviceRegistrationInfo::OnConnectedToCloud, AsWeakPtr()),
- base::Bind(&IgnoreCloudError));
+ base::Bind(&DeviceRegistrationInfo::OnConnectedToCloud, AsWeakPtr()));
}
-void DeviceRegistrationInfo::OnConnectedToCloud() {
+void DeviceRegistrationInfo::OnConnectedToCloud(ErrorPtr error) {
+ if (error)
+ return;
LOG(INFO) << "Device connected to cloud server";
connected_to_cloud_ = true;
FetchCommands(base::Bind(&DeviceRegistrationInfo::ProcessInitialCommandList,
- AsWeakPtr()),
- base::Bind(&IgnoreCloudError));
+ AsWeakPtr()));
// In case there are any pending state updates since we sent off the initial
// UpdateDeviceResource() request, update the server with any state changes.
PublishStateUpdates();
@@ -853,8 +805,7 @@
change.Commit();
if (HaveRegistrationCredentials()) {
- UpdateDeviceResource(base::Bind(&base::DoNothing),
- base::Bind(&IgnoreCloudError));
+ UpdateDeviceResource(base::Bind(&IgnoreCloudError));
}
}
@@ -891,12 +842,10 @@
void DeviceRegistrationInfo::UpdateCommand(
const std::string& command_id,
const base::DictionaryValue& command_patch,
- const base::Closure& on_success,
- const base::Closure& on_error) {
+ const DoneCallback& callback) {
DoCloudRequest(HttpClient::Method::kPatch,
GetServiceURL("commands/" + command_id), &command_patch,
- base::Bind(&IgnoreCloudResultWithCallback, on_success),
- base::Bind(&IgnoreCloudErrorWithCallback, on_error));
+ base::Bind(&IgnoreCloudResultWithCallback, callback));
}
void DeviceRegistrationInfo::NotifyCommandAborted(const std::string& command_id,
@@ -908,14 +857,12 @@
command_patch.Set(commands::attributes::kCommand_Error,
ErrorInfoToJson(*error).release());
}
- UpdateCommand(command_id, command_patch, base::Bind(&base::DoNothing),
- base::Bind(&base::DoNothing));
+ UpdateCommand(command_id, command_patch, base::Bind(&IgnoreCloudError));
}
void DeviceRegistrationInfo::UpdateDeviceResource(
- const base::Closure& on_success,
- const ErrorCallback& on_failure) {
- queued_resource_update_callbacks_.emplace_back(on_success, on_failure);
+ const DoneCallback& callback) {
+ queued_resource_update_callbacks_.emplace_back(callback);
if (!in_progress_resource_update_callbacks_.empty()) {
VLOG(1) << "Another request is already pending.";
return;
@@ -935,10 +882,8 @@
// the request to guard against out-of-order requests overwriting settings
// specified by later requests.
VLOG(1) << "Getting the last device resource timestamp from server...";
- GetDeviceInfo(
- base::Bind(&DeviceRegistrationInfo::OnDeviceInfoRetrieved, AsWeakPtr()),
- base::Bind(&DeviceRegistrationInfo::OnUpdateDeviceResourceError,
- AsWeakPtr()));
+ GetDeviceInfo(base::Bind(&DeviceRegistrationInfo::OnDeviceInfoRetrieved,
+ AsWeakPtr()));
return;
}
@@ -959,16 +904,16 @@
std::string url = GetDeviceURL(
{}, {{"lastUpdateTimeMs", last_device_resource_updated_timestamp_}});
- DoCloudRequest(
- HttpClient::Method::kPut, url, device_resource.get(),
- base::Bind(&DeviceRegistrationInfo::OnUpdateDeviceResourceSuccess,
- AsWeakPtr()),
- base::Bind(&DeviceRegistrationInfo::OnUpdateDeviceResourceError,
- AsWeakPtr()));
+ DoCloudRequest(HttpClient::Method::kPut, url, device_resource.get(),
+ base::Bind(&DeviceRegistrationInfo::OnUpdateDeviceResourceDone,
+ AsWeakPtr()));
}
void DeviceRegistrationInfo::OnDeviceInfoRetrieved(
- const base::DictionaryValue& device_info) {
+ const base::DictionaryValue& device_info,
+ ErrorPtr error) {
+ if (error)
+ return OnUpdateDeviceResourceError(std::move(error));
if (UpdateDeviceInfoTimestamp(device_info))
StartQueuedUpdateDeviceResource();
}
@@ -987,15 +932,18 @@
return true;
}
-void DeviceRegistrationInfo::OnUpdateDeviceResourceSuccess(
- const base::DictionaryValue& device_info) {
+void DeviceRegistrationInfo::OnUpdateDeviceResourceDone(
+ const base::DictionaryValue& device_info,
+ ErrorPtr error) {
+ if (error)
+ return OnUpdateDeviceResourceError(std::move(error));
UpdateDeviceInfoTimestamp(device_info);
// Make a copy of the callback list so that if the callback triggers another
// call to UpdateDeviceResource(), we do not modify the list we are iterating
// over.
auto callback_list = std::move(in_progress_resource_update_callbacks_);
- for (const auto& callback_pair : callback_list)
- callback_pair.first.Run();
+ for (const auto& callback : callback_list)
+ callback.Run(nullptr);
StartQueuedUpdateDeviceResource();
}
@@ -1004,10 +952,8 @@
// If the server rejected our previous request, retrieve the latest
// timestamp from the server and retry.
VLOG(1) << "Getting the last device resource timestamp from server...";
- GetDeviceInfo(
- base::Bind(&DeviceRegistrationInfo::OnDeviceInfoRetrieved, AsWeakPtr()),
- base::Bind(&DeviceRegistrationInfo::OnUpdateDeviceResourceError,
- AsWeakPtr()));
+ GetDeviceInfo(base::Bind(&DeviceRegistrationInfo::OnDeviceInfoRetrieved,
+ AsWeakPtr()));
return;
}
@@ -1015,28 +961,24 @@
// call to UpdateDeviceResource(), we do not modify the list we are iterating
// over.
auto callback_list = std::move(in_progress_resource_update_callbacks_);
- for (const auto& callback_pair : callback_list)
- callback_pair.second.Run(error->Clone());
+ for (const auto& callback : callback_list)
+ callback.Run(error->Clone());
StartQueuedUpdateDeviceResource();
}
-void DeviceRegistrationInfo::OnFetchCommandsSuccess(
- const base::Callback<void(const base::ListValue&)>& callback,
- const base::DictionaryValue& json) {
+void DeviceRegistrationInfo::OnFetchCommandsDone(
+ const base::Callback<void(const base::ListValue&, ErrorPtr)>& callback,
+ const base::DictionaryValue& json,
+ ErrorPtr error) {
OnFetchCommandsReturned();
+ if (error)
+ return callback.Run({}, std::move(error));
const base::ListValue* commands{nullptr};
- if (!json.GetList("commands", &commands)) {
+ if (!json.GetList("commands", &commands))
VLOG(2) << "No commands in the response.";
- }
const base::ListValue empty;
- callback.Run(commands ? *commands : empty);
-}
-
-void DeviceRegistrationInfo::OnFetchCommandsError(const ErrorCallback& callback,
- ErrorPtr error) {
- OnFetchCommandsReturned();
- callback.Run(std::move(error));
+ callback.Run(commands ? *commands : empty, nullptr);
}
void DeviceRegistrationInfo::OnFetchCommandsReturned() {
@@ -1047,17 +989,15 @@
}
void DeviceRegistrationInfo::FetchCommands(
- const base::Callback<void(const base::ListValue&)>& on_success,
- const ErrorCallback& on_failure) {
+ const base::Callback<void(const base::ListValue&, ErrorPtr error)>&
+ callback) {
fetch_commands_request_sent_ = true;
fetch_commands_request_queued_ = false;
DoCloudRequest(
HttpClient::Method::kGet,
GetServiceURL("commands/queue", {{"deviceId", GetSettings().cloud_id}}),
- nullptr, base::Bind(&DeviceRegistrationInfo::OnFetchCommandsSuccess,
- AsWeakPtr(), on_success),
- base::Bind(&DeviceRegistrationInfo::OnFetchCommandsError, AsWeakPtr(),
- on_failure));
+ nullptr, base::Bind(&DeviceRegistrationInfo::OnFetchCommandsDone,
+ AsWeakPtr(), callback));
}
void DeviceRegistrationInfo::FetchAndPublishCommands() {
@@ -1067,12 +1007,14 @@
}
FetchCommands(base::Bind(&DeviceRegistrationInfo::PublishCommands,
- weak_factory_.GetWeakPtr()),
- base::Bind(&IgnoreCloudError));
+ weak_factory_.GetWeakPtr()));
}
void DeviceRegistrationInfo::ProcessInitialCommandList(
- const base::ListValue& commands) {
+ const base::ListValue& commands,
+ ErrorPtr error) {
+ if (error)
+ return;
for (const base::Value* command : commands) {
const base::DictionaryValue* command_dict{nullptr};
if (!command->GetAsDictionary(&command_dict)) {
@@ -1098,8 +1040,7 @@
// TODO(wiley) We could consider handling this error case more gracefully.
DoCloudRequest(HttpClient::Method::kPut,
GetServiceURL("commands/" + command_id), cmd_copy.get(),
- base::Bind(&IgnoreCloudResult),
- base::Bind(&IgnoreCloudError));
+ base::Bind(&IgnoreCloudResult));
} else {
// Normal command, publish it to local clients.
PublishCommand(*command_dict);
@@ -1107,7 +1048,10 @@
}
}
-void DeviceRegistrationInfo::PublishCommands(const base::ListValue& commands) {
+void DeviceRegistrationInfo::PublishCommands(const base::ListValue& commands,
+ ErrorPtr error) {
+ if (error)
+ return;
for (const base::Value* command : commands) {
const base::DictionaryValue* command_dict{nullptr};
if (!command->GetAsDictionary(&command_dict)) {
@@ -1188,28 +1132,26 @@
body.Set("patches", patches.release());
device_state_update_pending_ = true;
- DoCloudRequest(
- HttpClient::Method::kPost, GetDeviceURL("patchState"), &body,
- base::Bind(&DeviceRegistrationInfo::OnPublishStateSuccess, AsWeakPtr(),
- update_id),
- base::Bind(&DeviceRegistrationInfo::OnPublishStateError, AsWeakPtr()));
+ DoCloudRequest(HttpClient::Method::kPost, GetDeviceURL("patchState"), &body,
+ base::Bind(&DeviceRegistrationInfo::OnPublishStateDone,
+ AsWeakPtr(), update_id));
}
-void DeviceRegistrationInfo::OnPublishStateSuccess(
+void DeviceRegistrationInfo::OnPublishStateDone(
StateChangeQueueInterface::UpdateID update_id,
- const base::DictionaryValue& reply) {
+ const base::DictionaryValue& reply,
+ ErrorPtr error) {
device_state_update_pending_ = false;
+ if (error) {
+ LOG(ERROR) << "Permanent failure while trying to update device state";
+ return;
+ }
state_manager_->NotifyStateUpdatedOnServer(update_id);
// See if there were more pending state updates since the previous request
// had been sent out.
PublishStateUpdates();
}
-void DeviceRegistrationInfo::OnPublishStateError(ErrorPtr error) {
- LOG(ERROR) << "Permanent failure while trying to update device state";
- device_state_update_pending_ = false;
-}
-
void DeviceRegistrationInfo::SetGcdState(GcdState new_state) {
VLOG_IF(1, new_state != gcd_state_) << "Changing registration status to "
<< EnumToString(new_state);
@@ -1223,8 +1165,7 @@
if (!HaveRegistrationCredentials() || !connected_to_cloud_)
return;
- UpdateDeviceResource(base::Bind(&base::DoNothing),
- base::Bind(&IgnoreCloudError));
+ UpdateDeviceResource(base::Bind(&IgnoreCloudError));
}
void DeviceRegistrationInfo::OnStateChanged() {
@@ -1257,8 +1198,9 @@
// the moment of the last poll and the time we successfully told the server
// to send new commands over the new notification channel.
UpdateDeviceResource(
- base::Bind(&DeviceRegistrationInfo::FetchAndPublishCommands, AsWeakPtr()),
- base::Bind(&IgnoreCloudError));
+ base::Bind(&IgnoreCloudErrorWithCallback,
+ base::Bind(&DeviceRegistrationInfo::FetchAndPublishCommands,
+ AsWeakPtr())));
}
void DeviceRegistrationInfo::OnDisconnected() {
@@ -1269,15 +1211,13 @@
pull_channel_->UpdatePullInterval(
base::TimeDelta::FromSeconds(kPollingPeriodSeconds));
current_notification_channel_ = pull_channel_.get();
- UpdateDeviceResource(base::Bind(&base::DoNothing),
- base::Bind(&IgnoreCloudError));
+ UpdateDeviceResource(base::Bind(&IgnoreCloudError));
}
void DeviceRegistrationInfo::OnPermanentFailure() {
LOG(ERROR) << "Failed to establish notification channel.";
notification_channel_starting_ = false;
RefreshAccessToken(
- base::Bind(&base::DoNothing),
base::Bind(&DeviceRegistrationInfo::CheckAccessTokenError, AsWeakPtr()));
}
diff --git a/libweave/src/device_registration_info.h b/libweave/src/device_registration_info.h
index a7f79b6..d389290 100644
--- a/libweave/src/device_registration_info.h
+++ b/libweave/src/device_registration_info.h
@@ -50,8 +50,9 @@
class DeviceRegistrationInfo : public NotificationDelegate,
public CloudCommandUpdateInterface {
public:
- using CloudRequestCallback =
- base::Callback<void(const base::DictionaryValue& response)>;
+ using CloudRequestDoneCallback =
+ base::Callback<void(const base::DictionaryValue& response,
+ ErrorPtr error)>;
DeviceRegistrationInfo(const std::shared_ptr<CommandManager>& command_manager,
const std::shared_ptr<StateManager>& state_manager,
@@ -65,8 +66,7 @@
void AddGcdStateChangedCallback(
const Device::GcdStateChangedCallback& callback);
void RegisterDevice(const std::string& ticket_id,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback);
+ const DoneCallback& callback);
void UpdateDeviceInfo(const std::string& name,
const std::string& description,
@@ -81,8 +81,7 @@
const std::string& service_url,
ErrorPtr* error);
- void GetDeviceInfo(const CloudRequestCallback& success_callback,
- const ErrorCallback& error_callback);
+ void GetDeviceInfo(const CloudRequestDoneCallback& callback);
// Returns the GCD service request URL. If |subpath| is specified, it is
// appended to the base URL which is normally
@@ -120,8 +119,7 @@
// Updates a command (override from CloudCommandUpdateInterface).
void UpdateCommand(const std::string& command_id,
const base::DictionaryValue& command_patch,
- const base::Closure& on_success,
- const base::Closure& on_error) override;
+ const DoneCallback& callback) override;
// TODO(vitalybuka): remove getters and pass config to dependent code.
const Config::Settings& GetSettings() const { return config_->GetSettings(); }
@@ -143,22 +141,17 @@
// Initiates the connection to the cloud server.
// Device will do required start up chores and then start to listen
// to new commands.
- void ConnectToCloud();
+ void ConnectToCloud(ErrorPtr error);
// Notification called when ConnectToCloud() succeeds.
- void OnConnectedToCloud();
+ void OnConnectedToCloud(ErrorPtr error);
// Forcibly refreshes the access token.
- void RefreshAccessToken(const base::Closure& success_callback,
- const ErrorCallback& error_callback);
+ void RefreshAccessToken(const DoneCallback& callback);
// Callbacks for RefreshAccessToken().
- void OnRefreshAccessTokenSuccess(
- const std::shared_ptr<base::Closure>& success_callback,
- const std::shared_ptr<ErrorCallback>& error_callback,
- const provider::HttpClient::Response& response);
- void OnRefreshAccessTokenError(
- const std::shared_ptr<base::Closure>& success_callback,
- const std::shared_ptr<ErrorCallback>& error_callback,
+ void OnRefreshAccessTokenDone(
+ const DoneCallback& callback,
+ std::unique_ptr<provider::HttpClient::Response> response,
ErrorPtr error);
// Parse the OAuth response, and sets registration status to
@@ -179,40 +172,36 @@
void DoCloudRequest(provider::HttpClient::Method method,
const std::string& url,
const base::DictionaryValue* body,
- const CloudRequestCallback& success_callback,
- const ErrorCallback& error_callback);
+ const CloudRequestDoneCallback& callback);
// Helper for DoCloudRequest().
struct CloudRequestData {
provider::HttpClient::Method method;
std::string url;
std::string body;
- CloudRequestCallback success_callback;
- ErrorCallback error_callback;
+ CloudRequestDoneCallback callback;
};
void SendCloudRequest(const std::shared_ptr<const CloudRequestData>& data);
- void OnCloudRequestSuccess(
+ void OnCloudRequestDone(
const std::shared_ptr<const CloudRequestData>& data,
- const provider::HttpClient::Response& response);
- void OnCloudRequestError(const std::shared_ptr<const CloudRequestData>& data,
- ErrorPtr error);
+ std::unique_ptr<provider::HttpClient::Response> response,
+ ErrorPtr error);
void RetryCloudRequest(const std::shared_ptr<const CloudRequestData>& data);
void OnAccessTokenRefreshed(
- const std::shared_ptr<const CloudRequestData>& data);
- void OnAccessTokenError(const std::shared_ptr<const CloudRequestData>& data,
- ErrorPtr error);
+ const std::shared_ptr<const CloudRequestData>& data,
+ ErrorPtr error);
void CheckAccessTokenError(ErrorPtr error);
- void UpdateDeviceResource(const base::Closure& on_success,
- const ErrorCallback& on_failure);
+ void UpdateDeviceResource(const DoneCallback& callback);
void StartQueuedUpdateDeviceResource();
- // Success/failure callbacks for UpdateDeviceResource().
- void OnUpdateDeviceResourceSuccess(const base::DictionaryValue& device_info);
+ void OnUpdateDeviceResourceDone(const base::DictionaryValue& device_info,
+ ErrorPtr error);
void OnUpdateDeviceResourceError(ErrorPtr error);
// Callback from GetDeviceInfo() to retrieve the device resource timestamp
// and retry UpdateDeviceResource() call.
- void OnDeviceInfoRetrieved(const base::DictionaryValue& device_info);
+ void OnDeviceInfoRetrieved(const base::DictionaryValue& device_info,
+ ErrorPtr error);
// Extracts the timestamp from the device resource and sets it to
// |last_device_resource_updated_timestamp_|.
@@ -221,13 +210,11 @@
bool UpdateDeviceInfoTimestamp(const base::DictionaryValue& device_info);
void FetchCommands(
- const base::Callback<void(const base::ListValue&)>& on_success,
- const ErrorCallback& on_failure);
- // Success/failure callbacks for FetchCommands().
- void OnFetchCommandsSuccess(
- const base::Callback<void(const base::ListValue&)>& callback,
- const base::DictionaryValue& json);
- void OnFetchCommandsError(const ErrorCallback& callback, ErrorPtr error);
+ const base::Callback<void(const base::ListValue&, ErrorPtr)>& callback);
+ void OnFetchCommandsDone(
+ const base::Callback<void(const base::ListValue&, ErrorPtr)>& callback,
+ const base::DictionaryValue& json,
+ ErrorPtr);
// Called when FetchCommands completes (with either success or error).
// This method reschedules any pending/queued fetch requests.
void OnFetchCommandsReturned();
@@ -235,9 +222,10 @@
// Processes the command list that is fetched from the server on connection.
// Aborts commands which are in transitional states and publishes queued
// commands which are queued.
- void ProcessInitialCommandList(const base::ListValue& commands);
+ void ProcessInitialCommandList(const base::ListValue& commands,
+ ErrorPtr error);
- void PublishCommands(const base::ListValue& commands);
+ void PublishCommands(const base::ListValue& commands, ErrorPtr error);
void PublishCommand(const base::DictionaryValue& command);
// Helper function to pull the pending command list from the server using
@@ -245,8 +233,9 @@
void FetchAndPublishCommands();
void PublishStateUpdates();
- void OnPublishStateSuccess(StateChangeQueueInterface::UpdateID update_id,
- const base::DictionaryValue& reply);
+ void OnPublishStateDone(StateChangeQueueInterface::UpdateID update_id,
+ const base::DictionaryValue& reply,
+ ErrorPtr error);
void OnPublishStateError(ErrorPtr error);
// If unrecoverable error occurred (e.g. error parsing command instance),
@@ -275,21 +264,22 @@
// Wipes out the device registration information and stops server connections.
void MarkDeviceUnregistered();
- struct RegisterCallbacks;
- void RegisterDeviceError(const std::shared_ptr<RegisterCallbacks>& callbacks,
- ErrorPtr error);
+ void RegisterDeviceError(const DoneCallback& callback, ErrorPtr error);
void RegisterDeviceOnTicketSent(
const std::string& ticket_id,
- const std::shared_ptr<RegisterCallbacks>& callbacks,
- const provider::HttpClient::Response& response);
+ const DoneCallback& callback,
+ std::unique_ptr<provider::HttpClient::Response> response,
+ ErrorPtr error);
void RegisterDeviceOnTicketFinalized(
- const std::shared_ptr<RegisterCallbacks>& callbacks,
- const provider::HttpClient::Response& response);
+ const DoneCallback& callback,
+ std::unique_ptr<provider::HttpClient::Response> response,
+ ErrorPtr error);
void RegisterDeviceOnAuthCodeSent(
const std::string& cloud_id,
const std::string& robot_account,
- const std::shared_ptr<RegisterCallbacks>& callbacks,
- const provider::HttpClient::Response& response);
+ const DoneCallback& callback,
+ std::unique_ptr<provider::HttpClient::Response> response,
+ ErrorPtr error);
// Transient data
std::string access_token_;
@@ -327,13 +317,12 @@
// another one was in flight.
bool fetch_commands_request_queued_{false};
- using ResourceUpdateCallbackList =
- std::vector<std::pair<base::Closure, ErrorCallback>>;
- // Success/error callbacks for device resource update request currently in
- // flight to the cloud server.
+ using ResourceUpdateCallbackList = std::vector<DoneCallback>;
+ // Callbacks for device resource update request currently in flight to the
+ // cloud server.
ResourceUpdateCallbackList in_progress_resource_update_callbacks_;
- // Success/error callbacks for device resource update requests queued while
- // another request is in flight to the cloud server.
+ // Callbacks for device resource update requests queued while another request
+ // is in flight to the cloud server.
ResourceUpdateCallbackList queued_resource_update_callbacks_;
std::unique_ptr<NotificationChannel> primary_notification_channel_;
diff --git a/libweave/src/device_registration_info_unittest.cc b/libweave/src/device_registration_info_unittest.cc
index 6d00a51..b086f8b 100644
--- a/libweave/src/device_registration_info_unittest.cc
+++ b/libweave/src/device_registration_info_unittest.cc
@@ -166,18 +166,19 @@
}
void PublishCommands(const base::ListValue& commands) {
- return dev_reg_->PublishCommands(commands);
+ dev_reg_->PublishCommands(commands, nullptr);
}
bool RefreshAccessToken(ErrorPtr* error) const {
bool succeeded = false;
- auto on_success = [&succeeded]() { succeeded = true; };
- auto on_failure = [&error](ErrorPtr in_error) {
- if (error)
+ auto callback = [&succeeded, &error](ErrorPtr in_error) {
+ if (error) {
*error = std::move(in_error);
+ return;
+ }
+ succeeded = true;
};
- dev_reg_->RefreshAccessToken(base::Bind(on_success),
- base::Bind(on_failure));
+ dev_reg_->RefreshAccessToken(base::Bind(callback));
return succeeded;
}
@@ -236,10 +237,10 @@
EXPECT_CALL(
http_client_,
SendRequest(HttpClient::Method::kPost, dev_reg_->GetOAuthURL("token"),
- HttpClient::Headers{GetFormHeader()}, _, _, _))
+ HttpClient::Headers{GetFormHeader()}, _, _))
.WillOnce(WithArgs<3, 4>(Invoke([](
const std::string& data,
- const HttpClient::SuccessCallback& callback) {
+ const HttpClient::SendRequestCallback& callback) {
EXPECT_EQ("refresh_token", GetFormField(data, "grant_type"));
EXPECT_EQ(test_data::kRefreshToken,
GetFormField(data, "refresh_token"));
@@ -251,7 +252,7 @@
json.SetString("access_token", test_data::kAccessToken);
json.SetInteger("expires_in", 3600);
- callback.Run(*ReplyWithJson(200, json));
+ callback.Run(ReplyWithJson(200, json), nullptr);
})));
EXPECT_TRUE(RefreshAccessToken(nullptr));
@@ -265,10 +266,10 @@
EXPECT_CALL(
http_client_,
SendRequest(HttpClient::Method::kPost, dev_reg_->GetOAuthURL("token"),
- HttpClient::Headers{GetFormHeader()}, _, _, _))
+ HttpClient::Headers{GetFormHeader()}, _, _))
.WillOnce(WithArgs<3, 4>(Invoke([](
const std::string& data,
- const HttpClient::SuccessCallback& callback) {
+ const HttpClient::SendRequestCallback& callback) {
EXPECT_EQ("refresh_token", GetFormField(data, "grant_type"));
EXPECT_EQ(test_data::kRefreshToken,
GetFormField(data, "refresh_token"));
@@ -278,7 +279,7 @@
base::DictionaryValue json;
json.SetString("error", "unable_to_authenticate");
- callback.Run(*ReplyWithJson(400, json));
+ callback.Run(ReplyWithJson(400, json), nullptr);
})));
ErrorPtr error;
@@ -294,10 +295,10 @@
EXPECT_CALL(
http_client_,
SendRequest(HttpClient::Method::kPost, dev_reg_->GetOAuthURL("token"),
- HttpClient::Headers{GetFormHeader()}, _, _, _))
+ HttpClient::Headers{GetFormHeader()}, _, _))
.WillOnce(WithArgs<3, 4>(Invoke([](
const std::string& data,
- const HttpClient::SuccessCallback& callback) {
+ const HttpClient::SendRequestCallback& callback) {
EXPECT_EQ("refresh_token", GetFormField(data, "grant_type"));
EXPECT_EQ(test_data::kRefreshToken,
GetFormField(data, "refresh_token"));
@@ -307,7 +308,7 @@
base::DictionaryValue json;
json.SetString("error", "invalid_grant");
- callback.Run(*ReplyWithJson(400, json));
+ callback.Run(ReplyWithJson(400, json), nullptr);
})));
ErrorPtr error;
@@ -320,30 +321,31 @@
ReloadSettings();
SetAccessToken();
- EXPECT_CALL(http_client_,
- SendRequest(HttpClient::Method::kGet, dev_reg_->GetDeviceURL(),
- HttpClient::Headers{GetAuthHeader(), GetJsonHeader()},
- _, _, _))
+ EXPECT_CALL(
+ http_client_,
+ SendRequest(HttpClient::Method::kGet, dev_reg_->GetDeviceURL(),
+ HttpClient::Headers{GetAuthHeader(), GetJsonHeader()}, _, _))
.WillOnce(WithArgs<3, 4>(
Invoke([](const std::string& data,
- const HttpClient::SuccessCallback& callback) {
+ const HttpClient::SendRequestCallback& callback) {
base::DictionaryValue json;
json.SetString("channel.supportedType", "xmpp");
json.SetString("deviceKind", "vendor");
json.SetString("id", test_data::kDeviceId);
json.SetString("kind", "clouddevices#device");
- callback.Run(*ReplyWithJson(200, json));
+ callback.Run(ReplyWithJson(200, json), nullptr);
})));
bool succeeded = false;
- auto on_success = [&succeeded, this](const base::DictionaryValue& info) {
+ auto callback = [&succeeded, this](const base::DictionaryValue& info,
+ ErrorPtr error) {
+ EXPECT_FALSE(error);
std::string id;
EXPECT_TRUE(info.GetString("id", &id));
EXPECT_EQ(test_data::kDeviceId, id);
succeeded = true;
};
- auto on_failure = [](ErrorPtr error) { FAIL() << "Should not be called"; };
- dev_reg_->GetDeviceInfo(base::Bind(on_success), base::Bind(on_failure));
+ dev_reg_->GetDeviceInfo(base::Bind(callback));
EXPECT_TRUE(succeeded);
}
@@ -386,10 +388,10 @@
EXPECT_CALL(http_client_,
SendRequest(HttpClient::Method::kPatch,
ticket_url + "?key=" + test_data::kApiKey,
- HttpClient::Headers{GetJsonHeader()}, _, _, _))
+ HttpClient::Headers{GetJsonHeader()}, _, _))
.WillOnce(WithArgs<3, 4>(Invoke([](
const std::string& data,
- const HttpClient::SuccessCallback& callback) {
+ const HttpClient::SendRequestCallback& callback) {
auto json = test::CreateDictionaryValue(data);
EXPECT_NE(nullptr, json.get());
std::string value;
@@ -449,15 +451,15 @@
device_draft->SetString("kind", "clouddevices#device");
json_resp.Set("deviceDraft", device_draft);
- callback.Run(*ReplyWithJson(200, json_resp));
+ callback.Run(ReplyWithJson(200, json_resp), nullptr);
})));
EXPECT_CALL(http_client_,
SendRequest(HttpClient::Method::kPost,
ticket_url + "/finalize?key=" + test_data::kApiKey,
- HttpClient::Headers{}, _, _, _))
- .WillOnce(
- WithArgs<4>(Invoke([](const HttpClient::SuccessCallback& callback) {
+ HttpClient::Headers{}, _, _))
+ .WillOnce(WithArgs<4>(
+ Invoke([](const HttpClient::SendRequestCallback& callback) {
base::DictionaryValue json;
json.SetString("id", test_data::kClaimTicketId);
json.SetString("kind", "clouddevices#registrationTicket");
@@ -469,16 +471,16 @@
json.SetString("robotAccountEmail", test_data::kRobotAccountEmail);
json.SetString("robotAccountAuthorizationCode",
test_data::kRobotAccountAuthCode);
- callback.Run(*ReplyWithJson(200, json));
+ callback.Run(ReplyWithJson(200, json), nullptr);
})));
EXPECT_CALL(
http_client_,
SendRequest(HttpClient::Method::kPost, dev_reg_->GetOAuthURL("token"),
- HttpClient::Headers{GetFormHeader()}, _, _, _))
+ HttpClient::Headers{GetFormHeader()}, _, _))
.WillOnce(WithArgs<3, 4>(Invoke([](
const std::string& data,
- const HttpClient::SuccessCallback& callback) {
+ const HttpClient::SendRequestCallback& callback) {
EXPECT_EQ("authorization_code", GetFormField(data, "grant_type"));
EXPECT_EQ(test_data::kRobotAccountAuthCode, GetFormField(data, "code"));
EXPECT_EQ(test_data::kClientId, GetFormField(data, "client_id"));
@@ -495,12 +497,13 @@
json.SetString("refresh_token", test_data::kRefreshToken);
json.SetInteger("expires_in", 3600);
- callback.Run(*ReplyWithJson(200, json));
+ callback.Run(ReplyWithJson(200, json), nullptr);
})));
bool done = false;
dev_reg_->RegisterDevice(
- test_data::kClaimTicketId, base::Bind([this, &done]() {
+ test_data::kClaimTicketId, base::Bind([this, &done](ErrorPtr error) {
+ EXPECT_FALSE(error);
done = true;
task_runner_.Break();
EXPECT_EQ(GcdState::kConnecting, GetGcdState());
@@ -511,8 +514,7 @@
dev_reg_->GetSettings().refresh_token);
EXPECT_EQ(test_data::kRobotAccountEmail,
dev_reg_->GetSettings().robot_account);
- }),
- base::Bind([](ErrorPtr error) { ADD_FAILURE(); }));
+ }));
task_runner_.Run();
EXPECT_TRUE(done);
}
@@ -573,51 +575,51 @@
};
TEST_F(DeviceRegistrationInfoUpdateCommandTest, SetProgress) {
- EXPECT_CALL(http_client_,
- SendRequest(HttpClient::Method::kPatch, command_url_,
- HttpClient::Headers{GetAuthHeader(), GetJsonHeader()},
- _, _, _))
+ EXPECT_CALL(
+ http_client_,
+ SendRequest(HttpClient::Method::kPatch, command_url_,
+ HttpClient::Headers{GetAuthHeader(), GetJsonHeader()}, _, _))
.WillOnce(WithArgs<3, 4>(Invoke([](
const std::string& data,
- const HttpClient::SuccessCallback& callback) {
+ const HttpClient::SendRequestCallback& callback) {
EXPECT_JSON_EQ((R"({"state":"inProgress","progress":{"progress":18}})"),
*CreateDictionaryValue(data));
base::DictionaryValue json;
- callback.Run(*ReplyWithJson(200, json));
+ callback.Run(ReplyWithJson(200, json), nullptr);
})));
EXPECT_TRUE(command_->SetProgress(*CreateDictionaryValue("{'progress':18}"),
nullptr));
}
TEST_F(DeviceRegistrationInfoUpdateCommandTest, Complete) {
- EXPECT_CALL(http_client_,
- SendRequest(HttpClient::Method::kPatch, command_url_,
- HttpClient::Headers{GetAuthHeader(), GetJsonHeader()},
- _, _, _))
+ EXPECT_CALL(
+ http_client_,
+ SendRequest(HttpClient::Method::kPatch, command_url_,
+ HttpClient::Headers{GetAuthHeader(), GetJsonHeader()}, _, _))
.WillOnce(WithArgs<3, 4>(Invoke([](
const std::string& data,
- const HttpClient::SuccessCallback& callback) {
+ const HttpClient::SendRequestCallback& callback) {
EXPECT_JSON_EQ(R"({"state":"done", "results":{"status":"Ok"}})",
*CreateDictionaryValue(data));
base::DictionaryValue json;
- callback.Run(*ReplyWithJson(200, json));
+ callback.Run(ReplyWithJson(200, json), nullptr);
})));
EXPECT_TRUE(
command_->Complete(*CreateDictionaryValue("{'status': 'Ok'}"), nullptr));
}
TEST_F(DeviceRegistrationInfoUpdateCommandTest, Cancel) {
- EXPECT_CALL(http_client_,
- SendRequest(HttpClient::Method::kPatch, command_url_,
- HttpClient::Headers{GetAuthHeader(), GetJsonHeader()},
- _, _, _))
+ EXPECT_CALL(
+ http_client_,
+ SendRequest(HttpClient::Method::kPatch, command_url_,
+ HttpClient::Headers{GetAuthHeader(), GetJsonHeader()}, _, _))
.WillOnce(WithArgs<3, 4>(Invoke([](
const std::string& data,
- const HttpClient::SuccessCallback& callback) {
+ const HttpClient::SendRequestCallback& callback) {
EXPECT_JSON_EQ(R"({"state":"cancelled"})",
*CreateDictionaryValue(data));
base::DictionaryValue json;
- callback.Run(*ReplyWithJson(200, json));
+ callback.Run(ReplyWithJson(200, json), nullptr);
})));
EXPECT_TRUE(command_->Cancel(nullptr));
}
diff --git a/libweave/src/notification/xmpp_channel.cc b/libweave/src/notification/xmpp_channel.cc
index 0d22286..c88ba45 100644
--- a/libweave/src/notification/xmpp_channel.cc
+++ b/libweave/src/notification/xmpp_channel.cc
@@ -106,10 +106,12 @@
}
}
-void XmppChannel::OnMessageRead(size_t size) {
+void XmppChannel::OnMessageRead(size_t size, ErrorPtr error) {
+ read_pending_ = false;
+ if (error)
+ return Restart();
std::string msg(read_socket_data_.data(), size);
VLOG(2) << "Received XMPP packet: '" << msg << "'";
- read_pending_ = false;
if (!size)
return Restart();
@@ -286,14 +288,24 @@
LOG(INFO) << "Starting XMPP connection to " << kDefaultXmppHost << ":"
<< kDefaultXmppPort;
- network_->OpenSslSocket(
- kDefaultXmppHost, kDefaultXmppPort,
- base::Bind(&XmppChannel::OnSslSocketReady,
- task_ptr_factory_.GetWeakPtr()),
- base::Bind(&XmppChannel::OnSslError, task_ptr_factory_.GetWeakPtr()));
+ network_->OpenSslSocket(kDefaultXmppHost, kDefaultXmppPort,
+ base::Bind(&XmppChannel::OnSslSocketReady,
+ task_ptr_factory_.GetWeakPtr()));
}
-void XmppChannel::OnSslSocketReady(std::unique_ptr<Stream> stream) {
+void XmppChannel::OnSslSocketReady(std::unique_ptr<Stream> stream,
+ ErrorPtr error) {
+ if (error) {
+ LOG(ERROR) << "TLS handshake failed. Restarting XMPP connection";
+ backoff_entry_.InformOfRequest(false);
+
+ LOG(INFO) << "Delaying connection to XMPP server for "
+ << backoff_entry_.GetTimeUntilRelease();
+ return task_runner_->PostDelayedTask(
+ FROM_HERE, base::Bind(&XmppChannel::CreateSslSocket,
+ task_ptr_factory_.GetWeakPtr()),
+ backoff_entry_.GetTimeUntilRelease());
+ }
CHECK(XmppState::kConnecting == state_);
backoff_entry_.InformOfRequest(true);
stream_ = std::move(stream);
@@ -302,18 +314,6 @@
ScheduleRegularPing();
}
-void XmppChannel::OnSslError(ErrorPtr error) {
- LOG(ERROR) << "TLS handshake failed. Restarting XMPP connection";
- backoff_entry_.InformOfRequest(false);
-
- LOG(INFO) << "Delaying connection to XMPP server for "
- << backoff_entry_.GetTimeUntilRelease();
- task_runner_->PostDelayedTask(
- FROM_HERE,
- base::Bind(&XmppChannel::CreateSslSocket, task_ptr_factory_.GetWeakPtr()),
- backoff_entry_.GetTimeUntilRelease());
-}
-
void XmppChannel::SendMessage(const std::string& message) {
CHECK(stream_) << "No XMPP socket stream available";
if (write_pending_) {
@@ -327,13 +327,13 @@
write_pending_ = true;
stream_->Write(
write_socket_data_.data(), write_socket_data_.size(),
- base::Bind(&XmppChannel::OnMessageSent, task_ptr_factory_.GetWeakPtr()),
- base::Bind(&XmppChannel::OnWriteError, task_ptr_factory_.GetWeakPtr()));
+ base::Bind(&XmppChannel::OnMessageSent, task_ptr_factory_.GetWeakPtr()));
}
-void XmppChannel::OnMessageSent() {
- ErrorPtr error;
+void XmppChannel::OnMessageSent(ErrorPtr error) {
write_pending_ = false;
+ if (error)
+ return Restart();
if (queued_write_data_.empty()) {
WaitForMessage();
} else {
@@ -348,18 +348,7 @@
read_pending_ = true;
stream_->Read(
read_socket_data_.data(), read_socket_data_.size(),
- base::Bind(&XmppChannel::OnMessageRead, task_ptr_factory_.GetWeakPtr()),
- base::Bind(&XmppChannel::OnReadError, task_ptr_factory_.GetWeakPtr()));
-}
-
-void XmppChannel::OnReadError(ErrorPtr error) {
- read_pending_ = false;
- Restart();
-}
-
-void XmppChannel::OnWriteError(ErrorPtr error) {
- write_pending_ = false;
- Restart();
+ base::Bind(&XmppChannel::OnMessageRead, task_ptr_factory_.GetWeakPtr()));
}
std::string XmppChannel::GetName() const {
diff --git a/libweave/src/notification/xmpp_channel.h b/libweave/src/notification/xmpp_channel.h
index 814b2a5..a40eca9 100644
--- a/libweave/src/notification/xmpp_channel.h
+++ b/libweave/src/notification/xmpp_channel.h
@@ -97,15 +97,12 @@
void RestartXmppStream();
void CreateSslSocket();
- void OnSslSocketReady(std::unique_ptr<Stream> stream);
- void OnSslError(ErrorPtr error);
+ void OnSslSocketReady(std::unique_ptr<Stream> stream, ErrorPtr error);
void WaitForMessage();
- void OnMessageRead(size_t size);
- void OnMessageSent();
- void OnReadError(ErrorPtr error);
- void OnWriteError(ErrorPtr error);
+ void OnMessageRead(size_t size, ErrorPtr error);
+ void OnMessageSent(ErrorPtr error);
void Restart();
void CloseStream();
diff --git a/libweave/src/notification/xmpp_channel_unittest.cc b/libweave/src/notification/xmpp_channel_unittest.cc
index c6e0be1..6c334dd 100644
--- a/libweave/src/notification/xmpp_channel_unittest.cc
+++ b/libweave/src/notification/xmpp_channel_unittest.cc
@@ -88,8 +88,9 @@
stream_{new test::FakeStream{task_runner_}},
fake_stream_{stream_.get()} {}
- void Connect(const base::Callback<void(std::unique_ptr<Stream>)>& callback) {
- callback.Run(std::move(stream_));
+ void Connect(const base::Callback<void(std::unique_ptr<Stream>,
+ ErrorPtr error)>& callback) {
+ callback.Run(std::move(stream_), nullptr);
}
XmppState state() const { return state_; }
@@ -121,7 +122,7 @@
class XmppChannelTest : public ::testing::Test {
protected:
XmppChannelTest() {
- EXPECT_CALL(network_, OpenSslSocket("talk.google.com", 5223, _, _))
+ EXPECT_CALL(network_, OpenSslSocket("talk.google.com", 5223, _))
.WillOnce(
WithArgs<2>(Invoke(&xmpp_client_, &FakeXmppChannel::Connect)));
}
diff --git a/libweave/src/privet/cloud_delegate.cc b/libweave/src/privet/cloud_delegate.cc
index 3df5b86..81b7e33 100644
--- a/libweave/src/privet/cloud_delegate.cc
+++ b/libweave/src/privet/cloud_delegate.cc
@@ -154,8 +154,7 @@
void AddCommand(const base::DictionaryValue& command,
const UserInfo& user_info,
- const CommandSuccessCallback& success_callback,
- const ErrorCallback& error_callback) override {
+ const CommandDoneCallback& callback) override {
CHECK(user_info.scope() != AuthScope::kNone);
CHECK_NE(user_info.user_id(), 0u);
@@ -166,44 +165,41 @@
Error::AddToPrintf(&error, FROM_HERE, errors::kDomain,
errors::kInvalidParams, "Invalid role: '%s'",
str_scope.c_str());
- return error_callback.Run(std::move(error));
+ return callback.Run({}, std::move(error));
}
std::string id;
if (!command_manager_->AddCommand(command, role, &id, &error))
- return error_callback.Run(std::move(error));
+ return callback.Run({}, std::move(error));
command_owners_[id] = user_info.user_id();
- success_callback.Run(*command_manager_->FindCommand(id)->ToJson());
+ callback.Run(*command_manager_->FindCommand(id)->ToJson(), nullptr);
}
void GetCommand(const std::string& id,
const UserInfo& user_info,
- const CommandSuccessCallback& success_callback,
- const ErrorCallback& error_callback) override {
+ const CommandDoneCallback& callback) override {
CHECK(user_info.scope() != AuthScope::kNone);
ErrorPtr error;
auto command = GetCommandInternal(id, user_info, &error);
if (!command)
- return error_callback.Run(std::move(error));
- success_callback.Run(*command->ToJson());
+ return callback.Run({}, std::move(error));
+ callback.Run(*command->ToJson(), nullptr);
}
void CancelCommand(const std::string& id,
const UserInfo& user_info,
- const CommandSuccessCallback& success_callback,
- const ErrorCallback& error_callback) override {
+ const CommandDoneCallback& callback) override {
CHECK(user_info.scope() != AuthScope::kNone);
ErrorPtr error;
auto command = GetCommandInternal(id, user_info, &error);
if (!command || !command->Cancel(&error))
- return error_callback.Run(std::move(error));
- success_callback.Run(*command->ToJson());
+ return callback.Run({}, std::move(error));
+ callback.Run(*command->ToJson(), nullptr);
}
void ListCommands(const UserInfo& user_info,
- const CommandSuccessCallback& success_callback,
- const ErrorCallback& error_callback) override {
+ const CommandDoneCallback& callback) override {
CHECK(user_info.scope() != AuthScope::kNone);
base::ListValue list_value;
@@ -218,7 +214,7 @@
base::DictionaryValue commands_json;
commands_json.Set("commands", list_value.DeepCopy());
- success_callback.Run(commands_json);
+ callback.Run(commands_json, nullptr);
}
private:
@@ -285,29 +281,27 @@
}
device_->RegisterDevice(
- ticket_id, base::Bind(&CloudDelegateImpl::RegisterDeviceSuccess,
- setup_weak_factory_.GetWeakPtr()),
- base::Bind(&CloudDelegateImpl::RegisterDeviceError,
+ ticket_id,
+ base::Bind(&CloudDelegateImpl::RegisterDeviceDone,
setup_weak_factory_.GetWeakPtr(), ticket_id, deadline));
}
- void RegisterDeviceSuccess() {
+ void RegisterDeviceDone(const std::string& ticket_id,
+ const base::Time& deadline,
+ ErrorPtr error) {
+ if (error) {
+ // Registration failed. Retry with backoff.
+ backoff_entry_.InformOfRequest(false);
+ return task_runner_->PostDelayedTask(
+ FROM_HERE,
+ base::Bind(&CloudDelegateImpl::CallManagerRegisterDevice,
+ setup_weak_factory_.GetWeakPtr(), ticket_id, deadline),
+ backoff_entry_.GetTimeUntilRelease());
+ }
backoff_entry_.InformOfRequest(true);
setup_state_ = SetupState(SetupState::kSuccess);
}
- void RegisterDeviceError(const std::string& ticket_id,
- const base::Time& deadline,
- ErrorPtr error) {
- // Registration failed. Retry with backoff.
- backoff_entry_.InformOfRequest(false);
- task_runner_->PostDelayedTask(
- FROM_HERE,
- base::Bind(&CloudDelegateImpl::CallManagerRegisterDevice,
- setup_weak_factory_.GetWeakPtr(), ticket_id, deadline),
- backoff_entry_.GetTimeUntilRelease());
- }
-
CommandInstance* GetCommandInternal(const std::string& command_id,
const UserInfo& user_info,
ErrorPtr* error) const {
diff --git a/libweave/src/privet/cloud_delegate.h b/libweave/src/privet/cloud_delegate.h
index 74456d3..05ba8a6 100644
--- a/libweave/src/privet/cloud_delegate.h
+++ b/libweave/src/privet/cloud_delegate.h
@@ -39,8 +39,9 @@
CloudDelegate();
virtual ~CloudDelegate();
- using CommandSuccessCallback =
- base::Callback<void(const base::DictionaryValue& commands)>;
+ using CommandDoneCallback =
+ base::Callback<void(const base::DictionaryValue& commands,
+ ErrorPtr error)>;
class Observer {
public:
@@ -107,25 +108,21 @@
// Adds command created from the given JSON representation.
virtual void AddCommand(const base::DictionaryValue& command,
const UserInfo& user_info,
- const CommandSuccessCallback& success_callback,
- const ErrorCallback& error_callback) = 0;
+ const CommandDoneCallback& callback) = 0;
// Returns command with the given ID.
virtual void GetCommand(const std::string& id,
const UserInfo& user_info,
- const CommandSuccessCallback& success_callback,
- const ErrorCallback& error_callback) = 0;
+ const CommandDoneCallback& callback) = 0;
// Cancels command with the given ID.
virtual void CancelCommand(const std::string& id,
const UserInfo& user_info,
- const CommandSuccessCallback& success_callback,
- const ErrorCallback& error_callback) = 0;
+ const CommandDoneCallback& callback) = 0;
// Lists commands.
virtual void ListCommands(const UserInfo& user_info,
- const CommandSuccessCallback& success_callback,
- const ErrorCallback& error_callback) = 0;
+ const CommandDoneCallback& callback) = 0;
void AddObserver(Observer* observer) { observer_list_.AddObserver(observer); }
void RemoveObserver(Observer* observer) {
diff --git a/libweave/src/privet/mock_delegates.h b/libweave/src/privet/mock_delegates.h
index 48227ae..f16b526 100644
--- a/libweave/src/privet/mock_delegates.h
+++ b/libweave/src/privet/mock_delegates.h
@@ -154,25 +154,19 @@
MOCK_CONST_METHOD0(GetCloudId, std::string());
MOCK_CONST_METHOD0(GetState, const base::DictionaryValue&());
MOCK_CONST_METHOD0(GetCommandDef, const base::DictionaryValue&());
- MOCK_METHOD4(AddCommand,
+ MOCK_METHOD3(AddCommand,
void(const base::DictionaryValue&,
const UserInfo&,
- const CommandSuccessCallback&,
- const ErrorCallback&));
- MOCK_METHOD4(GetCommand,
+ const CommandDoneCallback&));
+ MOCK_METHOD3(GetCommand,
void(const std::string&,
const UserInfo&,
- const CommandSuccessCallback&,
- const ErrorCallback&));
- MOCK_METHOD4(CancelCommand,
+ const CommandDoneCallback&));
+ MOCK_METHOD3(CancelCommand,
void(const std::string&,
const UserInfo&,
- const CommandSuccessCallback&,
- const ErrorCallback&));
- MOCK_METHOD3(ListCommands,
- void(const UserInfo&,
- const CommandSuccessCallback&,
- const ErrorCallback&));
+ const CommandDoneCallback&));
+ MOCK_METHOD2(ListCommands, void(const UserInfo&, const CommandDoneCallback&));
MockCloudDelegate() {
EXPECT_CALL(*this, GetDeviceId()).WillRepeatedly(Return("TestId"));
diff --git a/libweave/src/privet/privet_handler.cc b/libweave/src/privet/privet_handler.cc
index 702104d..ad1d3d0 100644
--- a/libweave/src/privet/privet_handler.cc
+++ b/libweave/src/privet/privet_handler.cc
@@ -198,23 +198,20 @@
}
void OnCommandRequestSucceeded(const PrivetHandler::RequestCallback& callback,
- const base::DictionaryValue& output) {
- callback.Run(http::kOk, output);
-}
+ const base::DictionaryValue& output,
+ ErrorPtr error) {
+ if (!error)
+ return callback.Run(http::kOk, output);
-void OnCommandRequestFailed(const PrivetHandler::RequestCallback& callback,
- ErrorPtr error) {
if (error->HasError("gcd", "unknown_command")) {
- ErrorPtr new_error = error->Clone();
- Error::AddTo(&new_error, FROM_HERE, errors::kDomain, errors::kNotFound,
+ Error::AddTo(&error, FROM_HERE, errors::kDomain, errors::kNotFound,
"Unknown command ID");
- return ReturnError(*new_error, callback);
+ return ReturnError(*error, callback);
}
if (error->HasError("gcd", "access_denied")) {
- ErrorPtr new_error = error->Clone();
- Error::AddTo(&new_error, FROM_HERE, errors::kDomain, errors::kAccessDenied,
+ Error::AddTo(&error, FROM_HERE, errors::kDomain, errors::kAccessDenied,
error->GetMessage());
- return ReturnError(*new_error, callback);
+ return ReturnError(*error, callback);
}
return ReturnError(*error, callback);
}
@@ -768,8 +765,7 @@
const UserInfo& user_info,
const RequestCallback& callback) {
cloud_->AddCommand(input, user_info,
- base::Bind(&OnCommandRequestSucceeded, callback),
- base::Bind(&OnCommandRequestFailed, callback));
+ base::Bind(&OnCommandRequestSucceeded, callback));
}
void PrivetHandler::HandleCommandsStatus(const base::DictionaryValue& input,
@@ -784,16 +780,14 @@
return ReturnError(*error, callback);
}
cloud_->GetCommand(id, user_info,
- base::Bind(&OnCommandRequestSucceeded, callback),
- base::Bind(&OnCommandRequestFailed, callback));
+ base::Bind(&OnCommandRequestSucceeded, callback));
}
void PrivetHandler::HandleCommandsList(const base::DictionaryValue& input,
const UserInfo& user_info,
const RequestCallback& callback) {
cloud_->ListCommands(user_info,
- base::Bind(&OnCommandRequestSucceeded, callback),
- base::Bind(&OnCommandRequestFailed, callback));
+ base::Bind(&OnCommandRequestSucceeded, callback));
}
void PrivetHandler::HandleCommandsCancel(const base::DictionaryValue& input,
@@ -808,8 +802,7 @@
return ReturnError(*error, callback);
}
cloud_->CancelCommand(id, user_info,
- base::Bind(&OnCommandRequestSucceeded, callback),
- base::Bind(&OnCommandRequestFailed, callback));
+ base::Bind(&OnCommandRequestSucceeded, callback));
}
} // namespace privet
diff --git a/libweave/src/privet/privet_handler_unittest.cc b/libweave/src/privet/privet_handler_unittest.cc
index 42dd956..dba394a 100644
--- a/libweave/src/privet/privet_handler_unittest.cc
+++ b/libweave/src/privet/privet_handler_unittest.cc
@@ -649,8 +649,11 @@
base::DictionaryValue command;
LoadTestJson(kInput, &command);
LoadTestJson("{'id':'5'}", &command);
- EXPECT_CALL(cloud_, AddCommand(_, _, _, _))
- .WillOnce(RunCallback<2, const base::DictionaryValue&>(command));
+ EXPECT_CALL(cloud_, AddCommand(_, _, _))
+ .WillOnce(WithArgs<2>(Invoke(
+ [&command](const CloudDelegate::CommandDoneCallback& callback) {
+ callback.Run(command, nullptr);
+ })));
EXPECT_PRED2(IsEqualJson, "{'name':'test', 'id':'5'}",
HandleRequest("/privet/v3/commands/execute", kInput));
@@ -661,18 +664,22 @@
base::DictionaryValue command;
LoadTestJson(kInput, &command);
LoadTestJson("{'name':'test'}", &command);
- EXPECT_CALL(cloud_, GetCommand(_, _, _, _))
- .WillOnce(RunCallback<2, const base::DictionaryValue&>(command));
+ EXPECT_CALL(cloud_, GetCommand(_, _, _))
+ .WillOnce(WithArgs<2>(Invoke(
+ [&command](const CloudDelegate::CommandDoneCallback& callback) {
+ callback.Run(command, nullptr);
+ })));
EXPECT_PRED2(IsEqualJson, "{'name':'test', 'id':'5'}",
HandleRequest("/privet/v3/commands/status", kInput));
ErrorPtr error;
Error::AddTo(&error, FROM_HERE, errors::kDomain, "notFound", "");
- EXPECT_CALL(cloud_, GetCommand(_, _, _, _))
- .WillOnce(WithArgs<3>(Invoke([&error](const ErrorCallback& callback) {
- callback.Run(std::move(error));
- })));
+ EXPECT_CALL(cloud_, GetCommand(_, _, _))
+ .WillOnce(WithArgs<2>(
+ Invoke([&error](const CloudDelegate::CommandDoneCallback& callback) {
+ callback.Run({}, std::move(error));
+ })));
EXPECT_PRED2(IsEqualError, CodeWithReason(404, "notFound"),
HandleRequest("/privet/v3/commands/status", "{'id': '15'}"));
@@ -682,18 +689,22 @@
const char kExpected[] = "{'id': '5', 'name':'test', 'state':'cancelled'}";
base::DictionaryValue command;
LoadTestJson(kExpected, &command);
- EXPECT_CALL(cloud_, CancelCommand(_, _, _, _))
- .WillOnce(RunCallback<2, const base::DictionaryValue&>(command));
+ EXPECT_CALL(cloud_, CancelCommand(_, _, _))
+ .WillOnce(WithArgs<2>(Invoke(
+ [&command](const CloudDelegate::CommandDoneCallback& callback) {
+ callback.Run(command, nullptr);
+ })));
EXPECT_PRED2(IsEqualJson, kExpected,
HandleRequest("/privet/v3/commands/cancel", "{'id': '8'}"));
ErrorPtr error;
Error::AddTo(&error, FROM_HERE, errors::kDomain, "notFound", "");
- EXPECT_CALL(cloud_, CancelCommand(_, _, _, _))
- .WillOnce(WithArgs<3>(Invoke([&error](const ErrorCallback& callback) {
- callback.Run(std::move(error));
- })));
+ EXPECT_CALL(cloud_, CancelCommand(_, _, _))
+ .WillOnce(WithArgs<2>(
+ Invoke([&error](const CloudDelegate::CommandDoneCallback& callback) {
+ callback.Run({}, std::move(error));
+ })));
EXPECT_PRED2(IsEqualError, CodeWithReason(404, "notFound"),
HandleRequest("/privet/v3/commands/cancel", "{'id': '11'}"));
@@ -709,8 +720,11 @@
base::DictionaryValue commands;
LoadTestJson(kExpected, &commands);
- EXPECT_CALL(cloud_, ListCommands(_, _, _))
- .WillOnce(RunCallback<1, const base::DictionaryValue&>(commands));
+ EXPECT_CALL(cloud_, ListCommands(_, _))
+ .WillOnce(WithArgs<1>(Invoke(
+ [&commands](const CloudDelegate::CommandDoneCallback& callback) {
+ callback.Run(commands, nullptr);
+ })));
EXPECT_PRED2(IsEqualJson, kExpected,
HandleRequest("/privet/v3/commands/list", "{}"));
diff --git a/libweave/src/privet/wifi_bootstrap_manager.cc b/libweave/src/privet/wifi_bootstrap_manager.cc
index b4b0153..1d2d813 100644
--- a/libweave/src/privet/wifi_bootstrap_manager.cc
+++ b/libweave/src/privet/wifi_bootstrap_manager.cc
@@ -96,21 +96,12 @@
<< ", pass=" << passphrase << ").";
UpdateState(State::kConnecting);
task_runner_->PostDelayedTask(
- FROM_HERE, base::Bind(&WifiBootstrapManager::OnConnectError,
- tasks_weak_factory_.GetWeakPtr(), nullptr),
+ FROM_HERE, base::Bind(&WifiBootstrapManager::OnConnectTimeout,
+ tasks_weak_factory_.GetWeakPtr()),
base::TimeDelta::FromMinutes(3));
wifi_->Connect(ssid, passphrase,
- base::Bind(&WifiBootstrapManager::OnConnectSuccess,
- tasks_weak_factory_.GetWeakPtr(), ssid),
- base::Bind(&WifiBootstrapManager::OnConnectError,
- tasks_weak_factory_.GetWeakPtr()));
-}
-
-void WifiBootstrapManager::OnConnectError(ErrorPtr error) {
- Error::AddTo(&error, FROM_HERE, errors::kDomain, errors::kInvalidState,
- "Failed to connect to provided network");
- setup_state_ = SetupState{std::move(error)};
- StartBootstrapping();
+ base::Bind(&WifiBootstrapManager::OnConnectDone,
+ tasks_weak_factory_.GetWeakPtr(), ssid));
}
void WifiBootstrapManager::EndConnecting() {}
@@ -203,7 +194,14 @@
return {WifiType::kWifi24};
}
-void WifiBootstrapManager::OnConnectSuccess(const std::string& ssid) {
+void WifiBootstrapManager::OnConnectDone(const std::string& ssid,
+ ErrorPtr error) {
+ if (error) {
+ Error::AddTo(&error, FROM_HERE, errors::kDomain, errors::kInvalidState,
+ "Failed to connect to provided network");
+ setup_state_ = SetupState{std::move(error)};
+ return StartBootstrapping();
+ }
VLOG(1) << "Wifi was connected successfully";
Config::Transaction change{config_};
change.set_last_configured_ssid(ssid);
@@ -212,6 +210,14 @@
StartMonitoring();
}
+void WifiBootstrapManager::OnConnectTimeout() {
+ ErrorPtr error;
+ Error::AddTo(&error, FROM_HERE, errors::kDomain, errors::kInvalidState,
+ "Timeout connecting to provided network");
+ setup_state_ = SetupState{std::move(error)};
+ return StartBootstrapping();
+}
+
void WifiBootstrapManager::OnBootstrapTimeout() {
VLOG(1) << "Bootstrapping has timed out.";
StartMonitoring();
diff --git a/libweave/src/privet/wifi_bootstrap_manager.h b/libweave/src/privet/wifi_bootstrap_manager.h
index 390af31..c0a1c24 100644
--- a/libweave/src/privet/wifi_bootstrap_manager.h
+++ b/libweave/src/privet/wifi_bootstrap_manager.h
@@ -87,8 +87,8 @@
// to return to monitoring mode periodically in case our connectivity issues
// were temporary.
void OnBootstrapTimeout();
- void OnConnectSuccess(const std::string& ssid);
- void OnConnectError(ErrorPtr error);
+ void OnConnectDone(const std::string& ssid, ErrorPtr error);
+ void OnConnectTimeout();
void OnConnectivityChange();
void OnMonitorTimeout();
void UpdateConnectionState();
diff --git a/libweave/src/streams.cc b/libweave/src/streams.cc
index 1dc0355..0516f66 100644
--- a/libweave/src/streams.cc
+++ b/libweave/src/streams.cc
@@ -19,52 +19,54 @@
void MemoryStream::Read(void* buffer,
size_t size_to_read,
- const ReadSuccessCallback& success_callback,
- const ErrorCallback& error_callback) {
+ const ReadCallback& callback) {
CHECK_LE(read_position_, data_.size());
size_t size_read = std::min(size_to_read, data_.size() - read_position_);
if (size_read > 0)
memcpy(buffer, data_.data() + read_position_, size_read);
read_position_ += size_read;
task_runner_->PostDelayedTask(FROM_HERE,
- base::Bind(success_callback, size_read), {});
+ base::Bind(callback, size_read, nullptr), {});
}
void MemoryStream::Write(const void* buffer,
size_t size_to_write,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) {
+ const WriteCallback& callback) {
data_.insert(data_.end(), static_cast<const char*>(buffer),
static_cast<const char*>(buffer) + size_to_write);
- task_runner_->PostDelayedTask(FROM_HERE, success_callback, {});
+ task_runner_->PostDelayedTask(FROM_HERE, base::Bind(callback, nullptr), {});
}
StreamCopier::StreamCopier(InputStream* source, OutputStream* destination)
: source_{source}, destination_{destination}, buffer_(4096) {}
-void StreamCopier::Copy(
- const InputStream::ReadSuccessCallback& success_callback,
- const ErrorCallback& error_callback) {
- source_->Read(
- buffer_.data(), buffer_.size(),
- base::Bind(&StreamCopier::OnSuccessRead, weak_ptr_factory_.GetWeakPtr(),
- success_callback, error_callback),
- error_callback);
+void StreamCopier::Copy(const InputStream::ReadCallback& callback) {
+ source_->Read(buffer_.data(), buffer_.size(),
+ base::Bind(&StreamCopier::OnReadDone,
+ weak_ptr_factory_.GetWeakPtr(), callback));
}
-void StreamCopier::OnSuccessRead(
- const InputStream::ReadSuccessCallback& success_callback,
- const ErrorCallback& error_callback,
- size_t size) {
+void StreamCopier::OnReadDone(const InputStream::ReadCallback& callback,
+ size_t size,
+ ErrorPtr error) {
+ if (error)
+ return callback.Run(0, std::move(error));
+
size_done_ += size;
if (size) {
return destination_->Write(
buffer_.data(), size,
- base::Bind(&StreamCopier::Copy, weak_ptr_factory_.GetWeakPtr(),
- success_callback, error_callback),
- error_callback);
+ base::Bind(&StreamCopier::OnWriteDone, weak_ptr_factory_.GetWeakPtr(),
+ callback));
}
- success_callback.Run(size_done_);
+ callback.Run(size_done_, nullptr);
+}
+
+void StreamCopier::OnWriteDone(const InputStream::ReadCallback& callback,
+ ErrorPtr error) {
+ if (error)
+ return callback.Run(size_done_, std::move(error));
+ Copy(callback);
}
} // namespace weave
diff --git a/libweave/src/streams.h b/libweave/src/streams.h
index 0a21737..6d2a1d0 100644
--- a/libweave/src/streams.h
+++ b/libweave/src/streams.h
@@ -21,13 +21,11 @@
void Read(void* buffer,
size_t size_to_read,
- const ReadSuccessCallback& success_callback,
- const ErrorCallback& error_callback) override;
+ const ReadCallback& callback) override;
void Write(const void* buffer,
size_t size_to_write,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) override;
+ const WriteCallback& callback) override;
const std::vector<uint8_t>& GetData() const { return data_; }
@@ -41,13 +39,13 @@
public:
StreamCopier(InputStream* source, OutputStream* destination);
- void Copy(const InputStream::ReadSuccessCallback& success_callback,
- const ErrorCallback& error_callback);
+ void Copy(const InputStream::ReadCallback& callback);
private:
- void OnSuccessRead(const InputStream::ReadSuccessCallback& success_callback,
- const ErrorCallback& error_callback,
- size_t size);
+ void OnWriteDone(const InputStream::ReadCallback& callback, ErrorPtr error);
+ void OnReadDone(const InputStream::ReadCallback& callback,
+ size_t size,
+ ErrorPtr error);
InputStream* source_{nullptr};
OutputStream* destination_{nullptr};
diff --git a/libweave/src/streams_unittest.cc b/libweave/src/streams_unittest.cc
index 3cef6f0..b669328 100644
--- a/libweave/src/streams_unittest.cc
+++ b/libweave/src/streams_unittest.cc
@@ -23,13 +23,14 @@
bool done = false;
- auto on_success = base::Bind([&test_data, &done, &destination](size_t size) {
- done = true;
- EXPECT_EQ(test_data, destination.GetData());
- });
- auto on_error = base::Bind([](ErrorPtr error) { ADD_FAILURE(); });
+ auto callback = base::Bind(
+ [&test_data, &done, &destination](size_t size, ErrorPtr error) {
+ EXPECT_FALSE(error);
+ done = true;
+ EXPECT_EQ(test_data, destination.GetData());
+ });
StreamCopier copier{&source, &destination};
- copier.Copy(on_success, on_error);
+ copier.Copy(callback);
task_runner.Run(test_data.size());
EXPECT_TRUE(done);
diff --git a/libweave/src/test/fake_stream.cc b/libweave/src/test/fake_stream.cc
index 786aa01..7c86c4a 100644
--- a/libweave/src/test/fake_stream.cc
+++ b/libweave/src/test/fake_stream.cc
@@ -30,33 +30,30 @@
void FakeStream::Read(void* buffer,
size_t size_to_read,
- const ReadSuccessCallback& success_callback,
- const ErrorCallback& error_callback) {
+ const ReadCallback& callback) {
if (read_data_.empty()) {
task_runner_->PostDelayedTask(
- FROM_HERE,
- base::Bind(&FakeStream::Read, base::Unretained(this), buffer,
- size_to_read, success_callback, error_callback),
+ FROM_HERE, base::Bind(&FakeStream::Read, base::Unretained(this), buffer,
+ size_to_read, callback),
base::TimeDelta::FromSeconds(0));
return;
}
size_t size = std::min(size_to_read, read_data_.size());
memcpy(buffer, read_data_.data(), size);
read_data_ = read_data_.substr(size);
- task_runner_->PostDelayedTask(FROM_HERE, base::Bind(success_callback, size),
+ task_runner_->PostDelayedTask(FROM_HERE, base::Bind(callback, size, nullptr),
base::TimeDelta::FromSeconds(0));
}
void FakeStream::Write(const void* buffer,
size_t size_to_write,
- const SuccessCallback& success_callback,
- const ErrorCallback& error_callback) {
+ const WriteCallback& callback) {
size_t size = std::min(size_to_write, write_data_.size());
EXPECT_EQ(
write_data_.substr(0, size),
std::string(reinterpret_cast<const char*>(buffer), size_to_write));
write_data_ = write_data_.substr(size);
- task_runner_->PostDelayedTask(FROM_HERE, success_callback,
+ task_runner_->PostDelayedTask(FROM_HERE, base::Bind(callback, nullptr),
base::TimeDelta::FromSeconds(0));
}
diff --git a/libweave/src/weave_unittest.cc b/libweave/src/weave_unittest.cc
index 8d87229..74e8c41 100644
--- a/libweave/src/weave_unittest.cc
+++ b/libweave/src/weave_unittest.cc
@@ -140,9 +140,9 @@
void ExpectRequest(HttpClient::Method method,
const std::string& url,
const std::string& json_response) {
- EXPECT_CALL(http_client_, SendRequest(method, url, _, _, _, _))
+ EXPECT_CALL(http_client_, SendRequest(method, url, _, _, _))
.WillOnce(WithArgs<4>(Invoke([json_response](
- const HttpClient::SuccessCallback& callback) {
+ const HttpClient::SendRequestCallback& callback) {
std::unique_ptr<provider::test::MockHttpClientResponse> response{
new StrictMock<provider::test::MockHttpClientResponse>};
EXPECT_CALL(*response, GetStatusCode())
@@ -154,7 +154,7 @@
EXPECT_CALL(*response, GetData())
.Times(AtLeast(1))
.WillRepeatedly(Return(json_response));
- callback.Run(*response);
+ callback.Run(std::move(response), nullptr);
})));
}
@@ -306,7 +306,7 @@
}
TEST_F(WeaveBasicTest, Register) {
- EXPECT_CALL(network_, OpenSslSocket(_, _, _, _)).WillRepeatedly(Return());
+ EXPECT_CALL(network_, OpenSslSocket(_, _, _)).WillRepeatedly(Return());
StartDevice();
auto draft = CreateDictionaryValue(kDeviceResource);
@@ -333,12 +333,12 @@
InitDnsSdPublishing(true, "DB");
bool done = false;
- device_->Register("TICKET_ID", base::Bind([this, &done]() {
+ device_->Register("TICKET_ID", base::Bind([this, &done](ErrorPtr error) {
+ EXPECT_FALSE(error);
done = true;
task_runner_.Break();
EXPECT_EQ("CLOUD_ID", device_->GetSettings().cloud_id);
- }),
- base::Bind([](ErrorPtr error) { ADD_FAILURE(); }));
+ }));
task_runner_.Run();
EXPECT_TRUE(done);
}