mirror of
https://github.com/BillyOutlast/rocm-stable-diffusion.cpp.git
synced 2026-02-04 03:01:18 +01:00
Conv2D direct support (#744)
* Conv2DDirect for VAE stage * Enable only for Vulkan, reduced duplicated code * Cmake option to use conv2d direct * conv2d direct always on for opencl * conv direct as a flag * fix merge typo * Align conv2d behavior to flash attention's * fix readme * add conv2d direct for controlnet * add conv2d direct for esrgan * clean code, use enable_conv2d_direct/get_all_blocks * format code --------- Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
11
vae.hpp
11
vae.hpp
@@ -534,6 +534,17 @@ struct AutoEncoderKL : public GGMLRunner {
|
||||
ae.init(params_ctx, tensor_types, prefix);
|
||||
}
|
||||
|
||||
void enable_conv2d_direct() {
|
||||
std::vector<GGMLBlock*> blocks;
|
||||
ae.get_all_blocks(blocks);
|
||||
for (auto block : blocks) {
|
||||
if (block->get_desc() == "Conv2d") {
|
||||
auto conv_block = (Conv2d*)block;
|
||||
conv_block->enable_direct();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string get_desc() {
|
||||
return "vae";
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user