Conv2D direct support (#744)

* Conv2DDirect for VAE stage

* Enable only for Vulkan, reduced duplicated code

* Cmake option to use conv2d direct

* conv2d direct always on for opencl

* conv direct as a flag

* fix merge typo

* Align conv2d behavior to flash attention's

* fix readme

* add conv2d direct for controlnet

* add conv2d direct for esrgan

* clean code, use enable_conv2d_direct/get_all_blocks

* format code

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
Daniele
2025-08-02 17:25:17 +00:00
committed by GitHub
parent f7f05fb185
commit 5b8996f74a
11 changed files with 151 additions and 7 deletions

11
vae.hpp
View File

@@ -534,6 +534,17 @@ struct AutoEncoderKL : public GGMLRunner {
ae.init(params_ctx, tensor_types, prefix);
}
void enable_conv2d_direct() {
std::vector<GGMLBlock*> blocks;
ae.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
conv_block->enable_direct();
}
}
}
std::string get_desc() {
return "vae";
}