feat: add wan2.1/2.2 support (#778)

* add wan vae support

* add wan model support

* add umt5 support

* add wan2.1 t2i support

* make flash attn work with wan

* make wan a little faster

* add wan2.1 t2v support

* add wan gguf support

* add offload params to cpu support

* add wan2.1 i2v support

* crop image before resize

* set default fps to 16

* add diff lora support

* fix wan2.1 i2v

* introduce sd_sample_params_t

* add wan2.2 t2v support

* add wan2.2 14B i2v support

* add wan2.2 ti2v support

* add high noise lora support

* sync: update ggml submodule url

* avoid build failure on linux

* avoid build failure

* update ggml

* update ggml

* fix sd_version_is_wan

* update ggml, fix cpu im2col_3d

* fix ggml_nn_attention_ext mask

* add cache support to ggml runner

* fix the issue of illegal memory access

* unify image loading processing

* add wan2.1/2.2 FLF2V support

* fix end_image mask

* update to latest ggml

* add GGUFReader

* update docs
Author: leejet
Date: 2025-09-06 18:08:03 +08:00
Committed by: GitHub
Parent: 2eb3845df5
Commit: cb1d975e96

46 changed files with 768088 additions and 1427 deletions

model.cpp (147 changes)

@@ -6,10 +6,12 @@
 #include <unordered_map>
 #include <vector>
+#include "gguf_reader.hpp"
 #include "model.h"
 #include "stable-diffusion.h"
 #include "util.h"
 #include "vocab.hpp"
+#include "vocab_umt5.hpp"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
@@ -88,6 +90,7 @@ const char* unused_tensors[] = {
     "posterior_mean_coef1",
     "posterior_mean_coef2",
     "cond_stage_model.transformer.text_model.embeddings.position_ids",
+    "cond_stage_model.transformer.vision_model.embeddings.position_ids",
     "cond_stage_model.model.logit_scale",
     "cond_stage_model.model.text_projection",
     "conditioner.embedders.0.transformer.text_model.embeddings.position_ids",
@@ -141,6 +144,11 @@ std::unordered_map<std::string, std::string> open_clip_to_hf_clip_resblock = {
     {"mlp.c_proj.weight", "mlp.fc2.weight"},
 };
 
+std::unordered_map<std::string, std::string> cond_model_name_map = {
+    {"transformer.vision_model.pre_layrnorm.weight", "transformer.vision_model.pre_layernorm.weight"},
+    {"transformer.vision_model.pre_layrnorm.bias", "transformer.vision_model.pre_layernorm.bias"},
+};
+
 std::unordered_map<std::string, std::string> vae_decoder_name_map = {
     {"first_stage_model.decoder.mid.attn_1.to_k.bias", "first_stage_model.decoder.mid.attn_1.k.bias"},
     {"first_stage_model.decoder.mid.attn_1.to_k.weight", "first_stage_model.decoder.mid.attn_1.k.weight"},
@@ -179,7 +187,7 @@ std::unordered_map<std::string, std::string> pmid_v2_name_map = {
     "pmid.qformer_perceiver.token_proj.fc2.weight"},
 };
 
-std::string convert_open_clip_to_hf_clip(const std::string& name) {
+std::string convert_cond_model_name(const std::string& name) {
     std::string new_name = name;
     std::string prefix;
     if (contains(new_name, ".enc.")) {
@@ -268,6 +276,10 @@ std::string convert_open_clip_to_hf_clip(const std::string& name) {
         new_name = open_clip_to_hf_clip_model[new_name];
     }
 
+    if (cond_model_name_map.find(new_name) != cond_model_name_map.end()) {
+        new_name = cond_model_name_map[new_name];
+    }
+
     std::string open_clip_resblock_prefix = "model.transformer.resblocks.";
     std::string hf_clip_resblock_prefix = "transformer.text_model.encoder.layers.";
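Note on the map introduced above: `pre_layrnorm` is not a typo in this patch. It is the actual attribute name in Hugging Face's CLIP vision tower (a historical misspelling kept for checkpoint compatibility), which the map normalizes to the `pre_layernorm` spelling used on this side. A minimal sketch of the lookup's effect, with the input tensor name assumed rather than taken from the patch:

```cpp
// Sketch: the added lookup rewrites HF's "pre_layrnorm" spelling.
std::string new_name = "transformer.vision_model.pre_layrnorm.weight";  // assumed input
auto it = cond_model_name_map.find(new_name);
if (it != cond_model_name_map.end()) {
    new_name = it->second;  // "transformer.vision_model.pre_layernorm.weight"
}
```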
@@ -563,7 +575,7 @@ std::string convert_tensor_name(std::string name) {
     // }
     std::string new_name = name;
     if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || starts_with(name, "text_encoders.") || ends_with(name, ".vision_model.visual_projection.weight")) {
-        new_name = convert_open_clip_to_hf_clip(name);
+        new_name = convert_cond_model_name(name);
     } else if (starts_with(name, "first_stage_model.decoder")) {
         new_name = convert_vae_decoder_name(name);
     } else if (starts_with(name, "pmid.qformer_perceiver")) {
@@ -592,9 +604,11 @@ std::string convert_tensor_name(std::string name) {
         } else {
             new_name = name;
         }
+    } else if (ends_with(name, ".diff") || ends_with(name, ".diff_b")) {
+        new_name = "lora." + name;
     } else if (contains(name, "lora_up") || contains(name, "lora_down") ||
               contains(name, "lora.up") || contains(name, "lora.down") ||
-              contains(name, "lora_linear")) {
+              contains(name, "lora_linear") || ends_with(name, ".alpha")) {
         size_t pos = new_name.find(".processor");
         if (pos != std::string::npos) {
             new_name.replace(pos, strlen(".processor"), "");
@@ -602,7 +616,11 @@ std::string convert_tensor_name(std::string name) {
         // if (starts_with(new_name, "transformer.transformer_blocks") || starts_with(new_name, "transformer.single_transformer_blocks")) {
         //     new_name = "model.diffusion_model." + new_name;
         // }
-        pos = new_name.rfind("lora");
+        if (ends_with(name, ".alpha")) {
+            pos = new_name.rfind("alpha");
+        } else {
+            pos = new_name.rfind("lora");
+        }
         if (pos != std::string::npos) {
             std::string name_without_network_parts = new_name.substr(0, pos - 1);
             std::string network_part = new_name.substr(pos);
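A worked example of the split above, using a made-up LoRA tensor name. For `.alpha` tensors the split point is the `alpha` suffix rather than the last `lora` substring, and the `pos - 1` drops the separator character as well:

```cpp
// Illustrative values only; the variable names mirror the hunk above.
std::string new_name = "lora_unet_blocks_0_attn.alpha";                 // hypothetical input
size_t pos = new_name.rfind("alpha");                                   // index of "alpha"
std::string name_without_network_parts = new_name.substr(0, pos - 1);   // "lora_unet_blocks_0_attn"
std::string network_part = new_name.substr(pos);                        // "alpha"
```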
@@ -684,6 +702,13 @@ void preprocess_tensor(TensorStorage tensor_storage,
         tensor_storage.unsqueeze();
     }
 
+    // wan vae
+    if (ends_with(new_name, "gamma")) {
+        tensor_storage.reverse_ne();
+        tensor_storage.n_dims = 1;
+        tensor_storage.reverse_ne();
+    }
+
     tensor_storage.name = new_name;
 
     if (new_name.find("cond_stage_model") != std::string::npos &&
@@ -1030,10 +1055,38 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
     gguf_context* ctx_gguf_ = NULL;
     ggml_context* ctx_meta_ = NULL;
     ctx_gguf_ = gguf_init_from_file(file_path.c_str(), {true, &ctx_meta_});
     if (!ctx_gguf_) {
-        LOG_ERROR("failed to open '%s'", file_path.c_str());
-        return false;
+        LOG_ERROR("failed to open '%s' with gguf_init_from_file. Try to open it with GGUFReader.", file_path.c_str());
+        GGUFReader gguf_reader;
+        if (!gguf_reader.load(file_path)) {
+            LOG_ERROR("failed to open '%s' with GGUFReader.", file_path.c_str());
+            return false;
+        }
+        size_t data_offset = gguf_reader.data_offset();
+        for (const auto& gguf_tensor_info : gguf_reader.tensors()) {
+            std::string name = gguf_tensor_info.name;
+            if (!starts_with(name, prefix)) {
+                name = prefix + name;
+            }
+            TensorStorage tensor_storage(
+                name,
+                gguf_tensor_info.type,
+                gguf_tensor_info.shape.data(),
+                gguf_tensor_info.shape.size(),
+                file_index,
+                data_offset + gguf_tensor_info.offset);
+            // LOG_DEBUG("%s %s", name.c_str(), tensor_storage.to_string().c_str());
+            tensor_storages.push_back(tensor_storage);
+            add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
+        }
+        return true;
     }
 
     int n_tensors = gguf_get_n_tensors(ctx_gguf_);
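The new fallback keeps loading possible when ggml's stricter parser rejects a GGUF file. A minimal sketch of the two-stage open, under the assumption that `GGUFReader` (from the new gguf_reader.hpp) exposes exactly the members used in the hunk above:

```cpp
// Outline of the two-stage open: try ggml's native GGUF parser first, fall
// back to the commit's own GGUFReader when it rejects the file. This is a
// sketch, not the actual implementation.
bool open_gguf_with_fallback(const std::string& path) {
    ggml_context* meta = nullptr;
    gguf_context* ctx = gguf_init_from_file(path.c_str(), {/*no_alloc =*/true, &meta});
    if (ctx) {
        // happy path: enumerate tensors through the gguf_* API
        int n_tensors = gguf_get_n_tensors(ctx);
        (void)n_tensors;
        ggml_free(meta);
        gguf_free(ctx);
        return true;
    }
    GGUFReader reader;  // lightweight reader added by this commit
    if (!reader.load(path)) {
        return false;
    }
    for (const auto& info : reader.tensors()) {
        // info.name / info.type / info.shape / info.offset as used above,
        // with tensor data starting at reader.data_offset()
    }
    return true;
}
```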
@@ -1047,7 +1100,11 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
         // LOG_DEBUG("%s", name.c_str());
-        TensorStorage tensor_storage(prefix + name, dummy->type, dummy->ne, ggml_n_dims(dummy), file_index, offset);
+        if (!starts_with(name, prefix)) {
+            name = prefix + name;
+        }
+        TensorStorage tensor_storage(name, dummy->type, dummy->ne, ggml_n_dims(dummy), file_index, offset);
 
         GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());
@@ -1085,7 +1142,7 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
 
 // https://huggingface.co/docs/safetensors/index
 bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) {
-    LOG_DEBUG("init from '%s'", file_path.c_str());
+    LOG_DEBUG("init from '%s', prefix = '%s'", file_path.c_str(), prefix.c_str());
     file_paths_.push_back(file_path);
     size_t file_index = file_paths_.size() - 1;
 
     std::ifstream file(file_path, std::ios::binary);
@@ -1150,6 +1207,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         std::string dtype = tensor_info["dtype"];
         nlohmann::json shape = tensor_info["shape"];
 
+        if (dtype == "U8") {
+            continue;
+        }
+
         size_t begin = tensor_info["data_offsets"][0].get<size_t>();
         size_t end = tensor_info["data_offsets"][1].get<size_t>();
@@ -1171,12 +1232,11 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         }
 
         if (n_dims == 5) {
-            if (ne[3] == 1 && ne[4] == 1) {
-                n_dims = 4;
-            } else {
-                LOG_ERROR("invalid tensor '%s'", name.c_str());
-                return false;
-            }
+            n_dims = 4;
+            ne[0] = ne[0] * ne[1];
+            ne[1] = ne[2];
+            ne[2] = ne[3];
+            ne[3] = ne[4];
         }
 
         // ggml_n_dims returns 1 for scalars
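For intuition about the new 5-D branch: ggml tensors are limited to 4 dimensions, so a Conv3D weight `[out, in, kt, kh, kw]` is folded by merging the two leading (file-order) axes instead of being rejected. With assumed Wan2.2 TI2V dimensions (3072 output channels, 48 input channels, patch size 1x2x2), the fold looks like this:

```cpp
// Illustrative numbers; only the folding arithmetic comes from the patch.
int64_t ne[5] = {3072, 48, 1, 2, 2};  // file-order shape of patch_embedding.weight
int n_dims = 5;
n_dims = 4;
ne[0] = ne[0] * ne[1];  // 3072 * 48 = 147456
ne[1] = ne[2];          // 1
ne[2] = ne[3];          // 2
ne[3] = ne[4];          // 2
// After the later reverse_ne(), the 147456 lands in ne[3], which is what
// get_sd_version() reads as patch_embedding_channels further down.
```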
@@ -1184,7 +1244,11 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
             n_dims = 1;
         }
 
-        TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
+        if (!starts_with(name, prefix)) {
+            name = prefix + name;
+        }
+        TensorStorage tensor_storage(name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
         tensor_storage.reverse_ne();
 
         size_t tensor_data_size = end - begin;
@@ -1569,7 +1633,11 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
             reader.tensor_storage.file_index = file_index;
             // if(strcmp(prefix.c_str(), "scarlett") == 0)
             //     printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str());
-            reader.tensor_storage.name = prefix + reader.tensor_storage.name;
+            std::string name = reader.tensor_storage.name;
+            if (!starts_with(name, prefix)) {
+                name = prefix + name;
+            }
+            reader.tensor_storage.name = name;
             tensor_storages.push_back(reader.tensor_storage);
             add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type);
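This is the third loader (after the GGUF and safetensors paths above) to gain the same "prepend only if missing" guard, presumably so tensor names that already carry the prefix are not double-prefixed. A hypothetical helper capturing the pattern (not part of the commit):

```cpp
// Hypothetical consolidation; starts_with() is the project's own helper.
static std::string ensure_prefix(const std::string& name, const std::string& prefix) {
    return starts_with(name, prefix) ? name : prefix + name;
}
```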
@@ -1641,12 +1709,14 @@ SDVersion ModelLoader::get_sd_version() {
     bool has_multiple_encoders = false;
     bool is_unet = false;
 
     bool is_xl = false;
     bool is_flux = false;
+    bool is_wan = false;
+    int64_t patch_embedding_channels = 0;
+    bool has_img_emb = false;
 
-#define found_family (is_xl || is_flux)
     for (auto& tensor_storage : tensor_storages) {
-        if (!found_family) {
+        if (!(is_xl || is_flux)) {
             if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
                 is_flux = true;
                 if (input_block_checked) {
@@ -1656,6 +1726,15 @@ SDVersion ModelLoader::get_sd_version() {
             if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
                 return VERSION_SD3;
             }
+            if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
+                is_wan = true;
+            }
+            if (tensor_storage.name.find("model.diffusion_model.patch_embedding.weight") != std::string::npos) {
+                patch_embedding_channels = tensor_storage.ne[3];
+            }
+            if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) {
+                has_img_emb = true;
+            }
             if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
                 is_unet = true;
                 if (has_multiple_encoders) {
@@ -1690,11 +1769,21 @@ SDVersion ModelLoader::get_sd_version() {
         if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
             input_block_weight = tensor_storage;
             input_block_checked = true;
-            if (found_family) {
+            if (is_xl || is_flux) {
                 break;
             }
         }
     }
+
+    if (is_wan) {
+        LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels);
+        if (patch_embedding_channels == 184320 && !has_img_emb) {
+            return VERSION_WAN2_2_I2V;
+        }
+        if (patch_embedding_channels == 147456 && !has_img_emb) {
+            return VERSION_WAN2_2_TI2V;
+        }
+        return VERSION_WAN2;
+    }
+
     bool is_inpaint = input_block_weight.ne[2] == 9;
     bool is_ip2p = input_block_weight.ne[2] == 8;
     if (is_xl) {
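The magic numbers become less magic once you recall that `patch_embedding.weight` arrives through the 5-D fold shown earlier, so `ne[3]` holds `out_channels * in_channels`. Assuming the published Wan configs (dim 5120 with 36 input channels for the 2.2 14B i2v model, dim 3072 with 48 input channels for the 2.2 TI2V 5B model), the constants check out. The `!has_img_emb` guard then separates these from Wan2.1's i2v model, which has the same 36-channel patch embedding but ships a CLIP-vision `img_emb` tower that Wan2.2 dropped:

```cpp
// Hedged sanity check; the model dims are assumptions from the published
// Wan configs, not values stated in this patch.
static_assert(5120 * 36 == 184320, "Wan2.2 14B i2v: dim * in_channels");
static_assert(3072 * 48 == 147456, "Wan2.2 TI2V 5B: dim * in_channels");
```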
@@ -1850,6 +1939,11 @@ std::string ModelLoader::load_t5_tokenizer_json() {
     return json_str;
 }
 
+std::string ModelLoader::load_umt5_tokenizer_json() {
+    std::string json_str(reinterpret_cast<const char*>(umt5_tokenizer_json_str), sizeof(umt5_tokenizer_json_str));
+    return json_str;
+}
+
 std::vector<TensorStorage> remove_duplicates(const std::vector<TensorStorage>& vec) {
     std::vector<TensorStorage> res;
     std::unordered_map<std::string, size_t> name_to_index_map;
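`umt5_tokenizer_json_str` lives in the new vocab_umt5.hpp, which embeds the UMT5 tokenizer's JSON as a byte array (plausibly a large share of the commit's 768,088 added lines). A minimal sketch of that embedded-asset pattern, with the array contents elided:

```cpp
// Assumed generation: an xxd -i style dump of tokenizer.json (contents elided).
const unsigned char umt5_tokenizer_json_str[] = {0x7b, 0x22, /* ... */ 0x7d};

std::string load_umt5_tokenizer_json() {
    // sizeof() yields the full byte count, so embedded NUL bytes survive.
    return std::string(reinterpret_cast<const char*>(umt5_tokenizer_json_str),
                       sizeof(umt5_tokenizer_json_str));
}
```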
@@ -1871,7 +1965,7 @@ std::vector<TensorStorage> remove_duplicates(const std::vector<TensorStorage>& v
     return res;
 }
 
-bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend) {
+bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
     std::vector<TensorStorage> processed_tensor_storages;
     for (auto& tensor_storage : tensor_storages) {
         // LOG_DEBUG("%s", name.c_str());
@@ -2080,7 +2174,6 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
 }
 
 bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
-                               ggml_backend_t backend,
                                std::set<std::string> ignore_tensors) {
     std::set<std::string> tensor_names_in_file;
 
     auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
@@ -2120,7 +2213,7 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
         return true;
     };
 
-    bool success = load_tensors(on_new_tensor_cb, backend);
+    bool success = load_tensors(on_new_tensor_cb);
     if (!success) {
         LOG_ERROR("load tensors from file failed");
         return false;
@@ -2151,7 +2244,7 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
 std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
     std::vector<std::pair<std::string, ggml_type>> result;
-    for (const auto& item : splitString(tensor_type_rules, ',')) {
+    for (const auto& item : split_string(tensor_type_rules, ',')) {
         if (item.size() == 0)
             continue;
         std::string::size_type pos = item.find('=');
@@ -2264,7 +2357,7 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
         return true;
     };
 
-    bool success = load_tensors(on_new_tensor_cb, backend);
+    bool success = load_tensors(on_new_tensor_cb);
     ggml_backend_free(backend);
 
     LOG_INFO("load tensors done");
     LOG_INFO("trying to save tensors to %s", file_path.c_str());