feat: support for cancelling generations #1124
Force-pushed from d69eafa to de7bad2.
I was just checking this out, and it looks really promising! I made some edits to the design on my end, compartmentalizing the signal handler into its own object file so it can be reused between the CLI and the server. I also just had a successful test against sd-server receiving a cancel from a client hangup! I whipped up a quick patch of the changes in case that's helpful. The only other thing I can think of that might be useful is an initializer_list to set which signals get captured, but I think SIGUSR1 was a good default choice. Let me know what you think!
sd_cancel.patch
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 2dcd1d5..c1ae3b3 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -1,4 +1,8 @@
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
+add_library(signal_handler OBJECT common/signal_handler.cpp)
+target_include_directories(signal_handler PUBLIC ../include)
+
add_subdirectory(cli)
-add_subdirectory(server)
\ No newline at end of file
+add_subdirectory(server)
+
diff --git a/examples/cli/CMakeLists.txt b/examples/cli/CMakeLists.txt
index b30a2e8..f24af51 100644
--- a/examples/cli/CMakeLists.txt
+++ b/examples/cli/CMakeLists.txt
@@ -3,4 +3,5 @@ set(TARGET sd-cli)
add_executable(${TARGET} main.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17)
\ No newline at end of file
+target_link_libraries(${TARGET} PRIVATE signal_handler)
+target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17)
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index 503177c..e8e093c 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -474,12 +474,8 @@ bool save_results(const SDCliParams& cli_params,
return sucessful_reults != 0;
}
-#if defined(__unix__) || defined(__APPLE__) || defined(_POSIX_VERSION)
-#define SD_ENABLE_SIGNAL_HANDLER
-static void set_signal_cancel_handler(sd_ctx_t* sd_ctx);
-#else
-#define set_signal_cancel_handler(SD_CTX) ((void)SD_CTX)
-#endif
+#include "common/signal_handler.hpp"
+
int main(int argc, const char* argv[]) {
if (argc > 1 && std::string(argv[1]) == "--version") {
@@ -848,58 +844,3 @@ int main(int argc, const char* argv[]) {
return 0;
}
-#ifdef SD_ENABLE_SIGNAL_HANDLER
-
-#include <atomic>
-#include <csignal>
-#include <thread>
-#include <unistd.h>
-
-// this lock is needed to avoid a race condition between
-// free_sd_ctx and a pending sd_cancel_generation call
-std::atomic_flag signal_lock = ATOMIC_FLAG_INIT;
-static int g_sigint_cnt;
-static sd_ctx_t* g_sd_ctx;
-
-static void sig_cancel_handler(int /* signum */)
-{
- if (!signal_lock.test_and_set(std::memory_order_acquire)) {
- if (g_sd_ctx != nullptr) {
- if (g_sigint_cnt == 1) {
- char msg[] = "\ngot cancel signal, cancelling new generations\n";
- write(2, msg, sizeof(msg)-1);
- /* first signal cancels only the remaining latents on a batch */
- sd_cancel_generation(g_sd_ctx, SD_CANCEL_NEW_LATENTS);
- ++g_sigint_cnt;
- } else {
- char msg[] = "\ngot cancel signal, cancelling everything\n";
- write(2, msg, sizeof(msg)-1);
- /* cancels everything */
- sd_cancel_generation(g_sd_ctx, SD_CANCEL_ALL);
- }
- }
- signal_lock.clear(std::memory_order_release);
- }
-}
-
-static void set_signal_cancel_handler(sd_ctx_t* sd_ctx)
-{
- if (g_sigint_cnt == 0) {
- g_sigint_cnt++;
- struct sigaction sa{};
- sa.sa_handler = sig_cancel_handler;
- sa.sa_flags = SA_RESTART;
- sigaction(SIGUSR1, &sa, nullptr);
- }
-
- while (signal_lock.test_and_set(std::memory_order_acquire)) {
- std::this_thread::yield();
- }
-
- g_sd_ctx = sd_ctx;
-
- signal_lock.clear(std::memory_order_release);
-}
-
-#endif
-
diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt
index d191260..bc2d331 100644
--- a/examples/server/CMakeLists.txt
+++ b/examples/server/CMakeLists.txt
@@ -3,4 +3,5 @@ set(TARGET sd-server)
add_executable(${TARGET} main.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17)
\ No newline at end of file
+target_link_libraries(${TARGET} PRIVATE signal_handler)
+target_compile_features(${TARGET} PUBLIC c_std_11 cxx_std_17)
diff --git a/examples/server/main.cpp b/examples/server/main.cpp
index 0fb10c7..d8d27e6 100644
--- a/examples/server/main.cpp
+++ b/examples/server/main.cpp
@@ -7,6 +7,7 @@
#include <mutex>
#include <sstream>
#include <vector>
+#include <future>
#include "httplib.h"
#include "stable-diffusion.h"
@@ -268,6 +269,8 @@ struct LoraEntry {
std::string path;
};
+#include "common/signal_handler.hpp"
+
int main(int argc, const char** argv) {
if (argc > 1 && std::string(argv[1]) == "--version") {
std::cout << version_string() << "\n";
@@ -346,6 +349,8 @@ int main(int argc, const char** argv) {
[&](const LoraEntry& e) { return e.path == path; });
};
+ set_signal_cancel_handler(sd_ctx);
+
httplib::Server svr;
svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) {
@@ -507,11 +512,20 @@ int main(int argc, const char** argv) {
sd_image_t* results = nullptr;
int num_results = 0;
+ std::future<void> ft = std::async(std::launch::async, [&]()
{
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
results = generate_image(sd_ctx, &img_gen_params);
num_results = gen_params.batch_count;
}
+ );
+
+ std::future_status ft_status;
+ do {
+ if (!ft.valid()) break;
+ ft_status = ft.wait_for(std::chrono::milliseconds(1000));
+ if (req.is_connection_closed()) std::raise(SIGUSR1);
+ } while (ft_status != std::future_status::ready);
for (int i = 0; i < num_results; i++) {
if (results[i].data == nullptr) {
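For anyone skimming the server hunk above, the core idea is "run the generation on a worker, poll it with a timeout, and request a cancel if the client went away". Here is a minimal, standalone sketch of that pattern. It is only an illustration: `run_with_watchdog`, the `client_gone` predicate, and the shared `std::atomic<bool>` flag are hypothetical stand-ins, not sd-server or httplib APIs, and the flag replaces the `std::raise(SIGUSR1)` call used in the patch.

```cpp
#include <atomic>
#include <cassert>
#include <chrono>
#include <functional>
#include <future>
#include <thread>  // only needed by callers that sleep inside `work`

// Hypothetical helper, not part of sd-server or httplib.
// Runs `work` on a worker thread and polls `client_gone` once per `poll`
// interval. The first time the predicate reports true, the shared cancel
// flag is set (once). Returns true if a cancel was requested before the
// work finished.
bool run_with_watchdog(const std::function<void(const std::atomic<bool>&)>& work,
                       const std::function<bool()>& client_gone,
                       std::chrono::milliseconds poll) {
    assert(work && client_gone);

    std::atomic<bool> cancel{false};
    std::future<void> ft = std::async(std::launch::async, [&] { work(cancel); });

    // wait_for returns `timeout` while the worker is still running.
    while (ft.wait_for(poll) != std::future_status::ready) {
        if (!cancel.load() && client_gone()) {
            cancel.store(true);  // the worker is expected to check this flag
        }
    }

    ft.get();  // joins the worker and rethrows anything it threw
    return cancel.load();
}
```

Setting the flag only once is a deliberate choice in this sketch: with a two-stage handler, signalling again on every poll would escalate the cancel level each time.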
Co-authored-by: donington <[email protected]>
That was very helpful, thanks! And actually, the signal handler just complicates things for this: we can call I'm even tempted to drop the signal handling stuff from this PR, or maybe move it to a separate one, since the hangup handler ended up simpler and more portable.
I literally just got back to the computer and was reading your changes, and this is extremely clean now! I guess I didn't take a close enough look to realize that your cancellation was entirely isolated already when I whipped up my example of using futures.
It's true. You could probably simplify it down to using
Adds an sd_cancel_generation function that can be called asynchronously to interrupt the current generation.
The log handling is still a bit rough around the edges, but I wanted to gather more feedback before polishing it. I've included a flag to allow finer control over what to cancel: either everything, or only the current and next generations, keeping and decoding the already-generated latents. Would an extra "finish the already started latent but cancel the batch" mode be useful? Or should I simplify it instead, keeping just the cancel-everything mode?
The function should be safe to call from the progress or preview callbacks, a separate thread, or a signal handler. I've included a Unix signal handler in main.cpp just to be able to test it: the first Ctrl+C cancels the batch and the current generation but still finishes the already generated latents, while a second Ctrl+C cancels everything (although it won't interrupt it in the middle of a generation step anymore).
fixes #1036
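To make the two cancel levels easier to discuss, here is a rough, self-contained model of the behaviour described above. It is a sketch under assumptions, not the code in this PR: `CancelLevel`, `CancelState`, `run_batch`, and `BatchResult` are hypothetical names, the "generation" is a simulated loop, and the model assumes `std::atomic<int>` is lock-free on the target, which is what would make `request` and `escalate` reasonable to call from a signal handler.

```cpp
#include <atomic>
#include <cassert>
#include <functional>

// Hypothetical names throughout; this is not the stable-diffusion.cpp API.
enum class CancelLevel : int {
    None       = 0,  // keep generating
    NewLatents = 1,  // stop generating, but still decode finished latents
    All        = 2,  // stop generating and decode nothing
};

class CancelState {
public:
    // One atomic read-modify-write loop, no locks and no allocation.
    // The level only ever goes up, so a late NewLatents request cannot
    // downgrade an earlier All.
    void request(CancelLevel level) {
        const int want = static_cast<int>(level);
        int cur = level_.load(std::memory_order_relaxed);
        while (cur < want &&
               !level_.compare_exchange_weak(cur, want,
                                             std::memory_order_release,
                                             std::memory_order_relaxed)) {
        }
    }

    // Raise the level by one step per call, as a repeated Ctrl+C would.
    void escalate() {
        int cur = level_.load(std::memory_order_relaxed);
        while (cur < static_cast<int>(CancelLevel::All) &&
               !level_.compare_exchange_weak(cur, cur + 1,
                                             std::memory_order_release,
                                             std::memory_order_relaxed)) {
        }
    }

    CancelLevel level() const {
        return static_cast<CancelLevel>(level_.load(std::memory_order_acquire));
    }

private:
    std::atomic<int> level_{0};
};

struct BatchResult {
    int generated;  // latents whose sampling loop ran to completion
    int decoded;    // latents that would still be decoded into images
};

// Simulated batch: `on_step` plays the role of a progress callback and may
// itself request a cancel. The level is checked before every step.
BatchResult run_batch(CancelState& cancel, int batch_count, int steps,
                      const std::function<void(int, int)>& on_step) {
    assert(batch_count >= 0 && steps >= 0);
    int generated = 0;
    for (int latent = 0; latent < batch_count; ++latent) {
        bool finished = true;
        for (int step = 0; step < steps; ++step) {
            if (cancel.level() != CancelLevel::None) {
                finished = false;  // abandon the latent in progress
                break;
            }
            if (on_step) on_step(latent, step);
        }
        if (!finished) break;
        ++generated;
    }
    const int decoded = (cancel.level() == CancelLevel::All) ? 0 : generated;
    return {generated, decoded};
}
```

In this model a cancel takes effect at the next step boundary, so a request made during a latent's final step still lets that latent count as generated.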