feat: add SSD1B and tiny-sd support (#897)

* feat: add code and doc for running SSD1B models

* Extend the change to also support SD1.x models with tiny U-Nets.

* support SSD-1B.safetensors

* fix SD v1.5 diffusers-format loader

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
akleine
2025-10-25 17:35:54 +02:00
committed by GitHub
parent faabc5ad3c
commit 062490aa7c
6 changed files with 177 additions and 21 deletions

View File

@@ -330,6 +330,10 @@ std::string convert_cond_model_name(const std::string& name) {
return new_name;
}
if (new_name == "model.text_projection.weight") {
new_name = "transformer.text_model.text_projection";
}
if (open_clip_to_hf_clip_model.find(new_name) != open_clip_to_hf_clip_model.end()) {
new_name = open_clip_to_hf_clip_model[new_name];
}
@@ -623,6 +627,14 @@ std::string convert_tensor_name(std::string name) {
if (starts_with(name, "diffusion_model")) {
name = "model." + name;
}
if (starts_with(name, "model.diffusion_model.up_blocks.0.attentions.0.")) {
name.replace(0, sizeof("model.diffusion_model.up_blocks.0.attentions.0.") - 1,
"model.diffusion_model.output_blocks.0.1.");
}
if (starts_with(name, "model.diffusion_model.up_blocks.0.attentions.1.")) {
name.replace(0, sizeof("model.diffusion_model.up_blocks.0.attentions.1.") - 1,
"model.diffusion_model.output_blocks.1.1.");
}
// size_t pos = name.find("lora_A");
// if (pos != std::string::npos) {
// name.replace(pos, strlen("lora_A"), "lora_up");
@@ -1776,6 +1788,7 @@ SDVersion ModelLoader::get_sd_version() {
bool is_wan = false;
int64_t patch_embedding_channels = 0;
bool has_img_emb = false;
bool has_middle_block_1 = false;
for (auto& tensor_storage : tensor_storages) {
if (!(is_xl || is_flux)) {
@@ -1822,6 +1835,10 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_SVD;
}
}
if (tensor_storage.name.find("model.diffusion_model.middle_block.1.") != std::string::npos ||
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
has_middle_block_1 = true;
}
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
@@ -1834,7 +1851,7 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
input_block_weight = tensor_storage;
input_block_checked = true;
if (is_xl || is_flux) {
if (is_flux) {
break;
}
}
@@ -1858,6 +1875,9 @@ SDVersion ModelLoader::get_sd_version() {
if (is_ip2p) {
return VERSION_SDXL_PIX2PIX;
}
if (!has_middle_block_1) {
return VERSION_SDXL_SSD1B;
}
return VERSION_SDXL;
}
@@ -1881,6 +1901,9 @@ SDVersion ModelLoader::get_sd_version() {
if (is_ip2p) {
return VERSION_SD1_PIX2PIX;
}
if (!has_middle_block_1) {
return VERSION_SD1_TINY_UNET;
}
return VERSION_SD1;
} else if (token_embedding_weight.ne[0] == 1024) {
if (is_inpaint) {