refactor: introduce GGMLRunnerContext (#928)

* introduce GGMLRunnerContext

* add Flash Attention enable control through GGMLRunnerContext

* add conv2d_direct enable control through GGMLRunnerContext
This commit is contained in:
leejet
2025-11-02 02:11:04 +08:00
committed by GitHub
parent c42826b77c
commit 6103d86e2c
21 changed files with 1079 additions and 1199 deletions

View File

@@ -36,6 +36,7 @@ struct DiffusionModel {
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0;
virtual int64_t get_adm_in_channels() = 0;
virtual void set_flash_attn_enabled(bool enabled) = 0;
};
struct UNetModel : public DiffusionModel {
@@ -44,9 +45,8 @@ struct UNetModel : public DiffusionModel {
UNetModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
SDVersion version = VERSION_SD1,
bool flash_attn = false)
: unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn) {
SDVersion version = VERSION_SD1)
: unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version) {
}
std::string get_desc() override {
@@ -77,6 +77,10 @@ struct UNetModel : public DiffusionModel {
return unet.unet.adm_in_channels;
}
void set_flash_attn_enabled(bool enabled) {
unet.set_flash_attention_enabled(enabled);
}
void compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@@ -98,9 +102,8 @@ struct MMDiTModel : public DiffusionModel {
MMDiTModel(ggml_backend_t backend,
bool offload_params_to_cpu,
bool flash_attn = false,
const String2GGMLType& tensor_types = {})
: mmdit(backend, offload_params_to_cpu, flash_attn, tensor_types, "model.diffusion_model") {
: mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
}
std::string get_desc() override {
@@ -131,6 +134,10 @@ struct MMDiTModel : public DiffusionModel {
return 768 + 1280;
}
void set_flash_attn_enabled(bool enabled) {
mmdit.set_flash_attention_enabled(enabled);
}
void compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@@ -153,9 +160,8 @@ struct FluxModel : public DiffusionModel {
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
SDVersion version = VERSION_FLUX,
bool flash_attn = false,
bool use_mask = false)
: flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
: flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, use_mask) {
}
std::string get_desc() override {
@@ -186,6 +192,10 @@ struct FluxModel : public DiffusionModel {
return 768;
}
void set_flash_attn_enabled(bool enabled) {
flux.set_flash_attention_enabled(enabled);
}
void compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@@ -213,9 +223,8 @@ struct WanModel : public DiffusionModel {
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "model.diffusion_model",
SDVersion version = VERSION_WAN2,
bool flash_attn = false)
: prefix(prefix), wan(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
SDVersion version = VERSION_WAN2)
: prefix(prefix), wan(backend, offload_params_to_cpu, tensor_types, prefix, version) {
}
std::string get_desc() override {
@@ -246,6 +255,10 @@ struct WanModel : public DiffusionModel {
return 768;
}
void set_flash_attn_enabled(bool enabled) {
wan.set_flash_attention_enabled(enabled);
}
void compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,
@@ -272,9 +285,8 @@ struct QwenImageModel : public DiffusionModel {
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "model.diffusion_model",
SDVersion version = VERSION_QWEN_IMAGE,
bool flash_attn = false)
: prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
SDVersion version = VERSION_QWEN_IMAGE)
: prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version) {
}
std::string get_desc() override {
@@ -305,6 +317,10 @@ struct QwenImageModel : public DiffusionModel {
return 768;
}
void set_flash_attn_enabled(bool enabled) {
qwen_image.set_flash_attention_enabled(enabled);
}
void compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = nullptr,