mirror of
https://github.com/BillyOutlast/rocm-stable-diffusion.cpp.git
synced 2026-02-04 03:01:18 +01:00
refactor: introduce GGMLRunnerContext (#928)
* introduce GGMLRunnerContext * add Flash Attention enable control through GGMLRunnerContext * add conv2d_direct enable control through GGMLRunnerContext
This commit is contained in:
@@ -36,6 +36,7 @@ struct DiffusionModel {
|
||||
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
|
||||
virtual size_t get_params_buffer_size() = 0;
|
||||
virtual int64_t get_adm_in_channels() = 0;
|
||||
virtual void set_flash_attn_enabled(bool enabled) = 0;
|
||||
};
|
||||
|
||||
struct UNetModel : public DiffusionModel {
|
||||
@@ -44,9 +45,8 @@ struct UNetModel : public DiffusionModel {
|
||||
UNetModel(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types = {},
|
||||
SDVersion version = VERSION_SD1,
|
||||
bool flash_attn = false)
|
||||
: unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn) {
|
||||
SDVersion version = VERSION_SD1)
|
||||
: unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version) {
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
@@ -77,6 +77,10 @@ struct UNetModel : public DiffusionModel {
|
||||
return unet.unet.adm_in_channels;
|
||||
}
|
||||
|
||||
void set_flash_attn_enabled(bool enabled) {
|
||||
unet.set_flash_attention_enabled(enabled);
|
||||
}
|
||||
|
||||
void compute(int n_threads,
|
||||
DiffusionParams diffusion_params,
|
||||
struct ggml_tensor** output = nullptr,
|
||||
@@ -98,9 +102,8 @@ struct MMDiTModel : public DiffusionModel {
|
||||
|
||||
MMDiTModel(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
bool flash_attn = false,
|
||||
const String2GGMLType& tensor_types = {})
|
||||
: mmdit(backend, offload_params_to_cpu, flash_attn, tensor_types, "model.diffusion_model") {
|
||||
: mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
@@ -131,6 +134,10 @@ struct MMDiTModel : public DiffusionModel {
|
||||
return 768 + 1280;
|
||||
}
|
||||
|
||||
void set_flash_attn_enabled(bool enabled) {
|
||||
mmdit.set_flash_attention_enabled(enabled);
|
||||
}
|
||||
|
||||
void compute(int n_threads,
|
||||
DiffusionParams diffusion_params,
|
||||
struct ggml_tensor** output = nullptr,
|
||||
@@ -153,9 +160,8 @@ struct FluxModel : public DiffusionModel {
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types = {},
|
||||
SDVersion version = VERSION_FLUX,
|
||||
bool flash_attn = false,
|
||||
bool use_mask = false)
|
||||
: flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
|
||||
: flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, use_mask) {
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
@@ -186,6 +192,10 @@ struct FluxModel : public DiffusionModel {
|
||||
return 768;
|
||||
}
|
||||
|
||||
void set_flash_attn_enabled(bool enabled) {
|
||||
flux.set_flash_attention_enabled(enabled);
|
||||
}
|
||||
|
||||
void compute(int n_threads,
|
||||
DiffusionParams diffusion_params,
|
||||
struct ggml_tensor** output = nullptr,
|
||||
@@ -213,9 +223,8 @@ struct WanModel : public DiffusionModel {
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types = {},
|
||||
const std::string prefix = "model.diffusion_model",
|
||||
SDVersion version = VERSION_WAN2,
|
||||
bool flash_attn = false)
|
||||
: prefix(prefix), wan(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
|
||||
SDVersion version = VERSION_WAN2)
|
||||
: prefix(prefix), wan(backend, offload_params_to_cpu, tensor_types, prefix, version) {
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
@@ -246,6 +255,10 @@ struct WanModel : public DiffusionModel {
|
||||
return 768;
|
||||
}
|
||||
|
||||
void set_flash_attn_enabled(bool enabled) {
|
||||
wan.set_flash_attention_enabled(enabled);
|
||||
}
|
||||
|
||||
void compute(int n_threads,
|
||||
DiffusionParams diffusion_params,
|
||||
struct ggml_tensor** output = nullptr,
|
||||
@@ -272,9 +285,8 @@ struct QwenImageModel : public DiffusionModel {
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types = {},
|
||||
const std::string prefix = "model.diffusion_model",
|
||||
SDVersion version = VERSION_QWEN_IMAGE,
|
||||
bool flash_attn = false)
|
||||
: prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
|
||||
SDVersion version = VERSION_QWEN_IMAGE)
|
||||
: prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version) {
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
@@ -305,6 +317,10 @@ struct QwenImageModel : public DiffusionModel {
|
||||
return 768;
|
||||
}
|
||||
|
||||
void set_flash_attn_enabled(bool enabled) {
|
||||
qwen_image.set_flash_attention_enabled(enabled);
|
||||
}
|
||||
|
||||
void compute(int n_threads,
|
||||
DiffusionParams diffusion_params,
|
||||
struct ggml_tensor** output = nullptr,
|
||||
|
||||
Reference in New Issue
Block a user