feat: add wan2.1/2.2 support (#778)

* add wan vae suppport

* add wan model support

* add umt5 support

* add wan2.1 t2i support

* make flash attn work with wan

* make wan a little faster

* add wan2.1 t2v support

* add wan gguf support

* add offload params to cpu support

* add wan2.1 i2v support

* crop image before resize

* set default fps to 16

* add diff lora support

* fix wan2.1 i2v

* introduce sd_sample_params_t

* add wan2.2 t2v support

* add wan2.2 14B i2v support

* add wan2.2 ti2v support

* add high noise lora support

* sync: update ggml submodule url

* avoid build failure on linux

* avoid build failure

* update ggml

* update ggml

* fix sd_version_is_wan

* update ggml, fix cpu im2col_3d

* fix ggml_nn_attention_ext mask

* add cache support to ggml runner

* fix the issue of illegal memory access

* unify image loading processing

* add wan2.1/2.2 FLF2V support

* fix end_image mask

* update to latest ggml

* add GGUFReader

* update docs
This commit is contained in:
leejet
2025-09-06 18:08:03 +08:00
committed by GitHub
parent 2eb3845df5
commit cb1d975e96
46 changed files with 768088 additions and 1427 deletions

View File

@@ -4,8 +4,10 @@
#include "flux.hpp"
#include "mmdit.hpp"
#include "unet.hpp"
#include "wan.hpp"
struct DiffusionModel {
virtual std::string get_desc() = 0;
virtual void compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
@@ -32,10 +34,15 @@ struct UNetModel : public DiffusionModel {
UNetModelRunner unet;
UNetModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
SDVersion version = VERSION_SD1,
bool flash_attn = false)
: unet(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
: unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn) {
}
std::string get_desc() {
return unet.get_desc();
}
void alloc_params_buffer() {
@@ -85,8 +92,13 @@ struct MMDiTModel : public DiffusionModel {
MMDiTRunner mmdit;
MMDiTModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {})
: mmdit(backend, tensor_types, "model.diffusion_model") {
: mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
}
std::string get_desc() {
return mmdit.get_desc();
}
void alloc_params_buffer() {
@@ -135,11 +147,16 @@ struct FluxModel : public DiffusionModel {
Flux::FluxRunner flux;
FluxModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
SDVersion version = VERSION_FLUX,
bool flash_attn = false,
bool use_mask = false)
: flux(backend, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
: flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
}
std::string get_desc() {
return flux.get_desc();
}
void alloc_params_buffer() {
@@ -184,4 +201,63 @@ struct FluxModel : public DiffusionModel {
}
};
struct WanModel : public DiffusionModel {
std::string prefix;
WAN::WanRunner wan;
WanModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "model.diffusion_model",
SDVersion version = VERSION_WAN2,
bool flash_attn = false)
: prefix(prefix), wan(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
}
std::string get_desc() {
return wan.get_desc();
}
void alloc_params_buffer() {
wan.alloc_params_buffer();
}
void free_params_buffer() {
wan.free_params_buffer();
}
void free_compute_buffer() {
wan.free_compute_buffer();
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
wan.get_param_tensors(tensors, prefix);
}
size_t get_params_buffer_size() {
return wan.get_params_buffer_size();
}
int64_t get_adm_in_channels() {
return 768;
}
void compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
std::vector<ggml_tensor*> ref_latents = {},
int num_video_frames = -1,
std::vector<struct ggml_tensor*> controls = {},
float control_strength = 0.f,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
return wan.compute(n_threads, x, timesteps, context, y, c_concat, NULL, output, output_ctx);
}
};
#endif