feat: force using f32 for some layers

This commit is contained in:
leejet
2024-08-25 13:53:16 +08:00
parent 79c9fe9556
commit 1bdc767aaf
4 changed files with 26 additions and 15 deletions

View File

@@ -1740,9 +1740,17 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
// Pass, do not convert
} else if (ends_with(name, ".bias")) {
// Pass, do not convert
} else if (contains(name, "img_in.") || contains(name, "time_in.in_layer.") || contains(name, "vector_in.in_layer.") || contains(name, "guidance_in.in_layer.") || contains(name, "final_layer.linear.")) {
} else if (contains(name, "img_in.") ||
contains(name, "time_in.in_layer.") ||
contains(name, "vector_in.in_layer.") ||
contains(name, "guidance_in.in_layer.") ||
contains(name, "final_layer.linear.")) {
// Pass, do not convert. For FLUX
} else if (contains(name, "x_embedder.") || contains(name, "t_embedder.") || contains(name, "y_embedder.") || contains(name, "context_embedder.")) {
} else if (contains(name, "x_embedder.") ||
contains(name, "t_embedder.") ||
contains(name, "y_embedder.") ||
contains(name, "pos_embed") ||
contains(name, "context_embedder.")) {
// Pass, do not convert. For MMDiT
} else if (contains(name, "time_embed.") || contains(name, "label_emb.")) {
// Pass, do not convert. For Unet