mirror of
https://github.com/BillyOutlast/rocm-stable-diffusion.cpp.git
synced 2026-02-05 19:51:19 +01:00
fix: allow model and vae using different format
This commit is contained in:
@@ -3281,9 +3281,10 @@ struct LoraModel {
|
||||
bool load(ggml_backend_t backend_, std::string file_path) {
|
||||
backend = backend_;
|
||||
LOG_INFO("loading LoRA from '%s'", file_path.c_str());
|
||||
std::shared_ptr<ModelLoader> model_loader = std::shared_ptr<ModelLoader>(init_model_loader_from_file(file_path));
|
||||
ModelLoader model_loader;
|
||||
;
|
||||
|
||||
if (!model_loader) {
|
||||
if (!model_loader.init_from_file(file_path)) {
|
||||
LOG_ERROR("init lora model loader from file failed: '%s'", file_path.c_str());
|
||||
return false;
|
||||
}
|
||||
@@ -3299,10 +3300,10 @@ struct LoraModel {
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_type wtype = model_loader->get_sd_wtype();
|
||||
ggml_type wtype = model_loader.get_sd_wtype();
|
||||
|
||||
LOG_DEBUG("calculating buffer size");
|
||||
int64_t memory_buffer_size = model_loader->cal_mem_size();
|
||||
int64_t memory_buffer_size = model_loader.cal_mem_size();
|
||||
LOG_DEBUG("lora params backend buffer size = % 6.2f MB", memory_buffer_size / (1024.0 * 1024.0));
|
||||
|
||||
params_buffer_lora = ggml_backend_alloc_buffer(backend, memory_buffer_size);
|
||||
@@ -3320,7 +3321,7 @@ struct LoraModel {
|
||||
return true;
|
||||
};
|
||||
|
||||
model_loader->load_tensors(on_new_tensor_cb);
|
||||
model_loader.load_tensors(on_new_tensor_cb);
|
||||
|
||||
LOG_DEBUG("finished loaded lora");
|
||||
ggml_allocr_free(alloc);
|
||||
@@ -3664,21 +3665,21 @@ public:
|
||||
#endif
|
||||
#endif
|
||||
LOG_INFO("loading model from '%s'", model_path.c_str());
|
||||
std::shared_ptr<ModelLoader> model_loader = std::shared_ptr<ModelLoader>(init_model_loader_from_file(model_path));
|
||||
ModelLoader model_loader;
|
||||
|
||||
if (!model_loader) {
|
||||
if (!model_loader.init_from_file(model_path)) {
|
||||
LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (vae_path.size() > 0) {
|
||||
LOG_INFO("loading vae from '%s'", vae_path.c_str());
|
||||
if (!model_loader->init_from_file(vae_path, "vae.")) {
|
||||
if (!model_loader.init_from_file(vae_path, "vae.")) {
|
||||
LOG_WARN("loading vae from '%s' failed", vae_path.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
SDVersion version = model_loader->get_sd_version();
|
||||
SDVersion version = model_loader.get_sd_version();
|
||||
if (version == VERSION_COUNT) {
|
||||
LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str());
|
||||
return false;
|
||||
@@ -3687,7 +3688,7 @@ public:
|
||||
diffusion_model = UNetModel(version);
|
||||
LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]);
|
||||
if (wtype == GGML_TYPE_COUNT) {
|
||||
model_data_type = model_loader->get_sd_wtype();
|
||||
model_data_type = model_loader.get_sd_wtype();
|
||||
} else {
|
||||
model_data_type = wtype;
|
||||
}
|
||||
@@ -3697,7 +3698,7 @@ public:
|
||||
auto add_token = [&](const std::string& token, int32_t token_id) {
|
||||
cond_stage_model.tokenizer.add_token(token, token_id);
|
||||
};
|
||||
bool success = model_loader->load_vocab(add_token);
|
||||
bool success = model_loader.load_vocab(add_token);
|
||||
if (!success) {
|
||||
LOG_ERROR("get vocab from file failed: '%s'", model_path.c_str());
|
||||
return false;
|
||||
@@ -3794,7 +3795,7 @@ public:
|
||||
|
||||
// print_ggml_tensor(alphas_cumprod_tensor);
|
||||
|
||||
success = model_loader->load_tensors(on_new_tensor_cb);
|
||||
success = model_loader.load_tensors(on_new_tensor_cb);
|
||||
if (!success) {
|
||||
LOG_ERROR("load tensors from file failed");
|
||||
ggml_free(ctx);
|
||||
|
||||
Reference in New Issue
Block a user