mirror of
https://github.com/BillyOutlast/rocm-stable-diffusion.cpp.git
synced 2026-02-04 03:01:18 +01:00
feat: add sd3.5 support (#445)
This commit is contained in:
20
model.cpp
20
model.cpp
@@ -430,6 +430,14 @@ std::string convert_tensor_name(std::string name) {
|
||||
if (starts_with(name, "diffusion_model")) {
|
||||
name = "model." + name;
|
||||
}
|
||||
// size_t pos = name.find("lora_A");
|
||||
// if (pos != std::string::npos) {
|
||||
// name.replace(pos, strlen("lora_A"), "lora_up");
|
||||
// }
|
||||
// pos = name.find("lora_B");
|
||||
// if (pos != std::string::npos) {
|
||||
// name.replace(pos, strlen("lora_B"), "lora_down");
|
||||
// }
|
||||
std::string new_name = name;
|
||||
if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || starts_with(name, "text_encoders.") || ends_with(name, ".vision_model.visual_projection.weight")) {
|
||||
new_name = convert_open_clip_to_hf_clip(name);
|
||||
@@ -466,6 +474,9 @@ std::string convert_tensor_name(std::string name) {
|
||||
if (pos != std::string::npos) {
|
||||
new_name.replace(pos, strlen(".processor"), "");
|
||||
}
|
||||
// if (starts_with(new_name, "transformer.transformer_blocks") || starts_with(new_name, "transformer.single_transformer_blocks")) {
|
||||
// new_name = "model.diffusion_model." + new_name;
|
||||
// }
|
||||
pos = new_name.rfind("lora");
|
||||
if (pos != std::string::npos) {
|
||||
std::string name_without_network_parts = new_name.substr(0, pos - 1);
|
||||
@@ -1354,6 +1365,7 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
|
||||
SDVersion ModelLoader::get_sd_version() {
|
||||
TensorStorage token_embedding_weight;
|
||||
bool is_flux = false;
|
||||
bool is_sd3 = false;
|
||||
for (auto& tensor_storage : tensor_storages) {
|
||||
if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
|
||||
return VERSION_FLUX_DEV;
|
||||
@@ -1361,8 +1373,11 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
|
||||
is_flux = true;
|
||||
}
|
||||
if (tensor_storage.name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) {
|
||||
return VERSION_SD3_5_8B;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.23.") != std::string::npos) {
|
||||
return VERSION_SD3_2B;
|
||||
is_sd3 = true;
|
||||
}
|
||||
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) {
|
||||
return VERSION_SDXL;
|
||||
@@ -1387,6 +1402,9 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
if (is_flux) {
|
||||
return VERSION_FLUX_SCHNELL;
|
||||
}
|
||||
if (is_sd3) {
|
||||
return VERSION_SD3_2B;
|
||||
}
|
||||
if (token_embedding_weight.ne[0] == 768) {
|
||||
return VERSION_SD1;
|
||||
} else if (token_embedding_weight.ne[0] == 1024) {
|
||||
|
||||
Reference in New Issue
Block a user