Util: 2D convolve with channels

This commit is contained in:
Vicki Pfau 2021-04-15 22:10:58 -07:00
parent 9f099eab0b
commit 950767e6ad
2 changed files with 45 additions and 0 deletions

View File

@ -26,6 +26,7 @@ void ConvolutionKernelFillCircle(struct ConvolutionKernel* kernel, bool normaliz
void Convolve1DPad0PackedS32(const int32_t* restrict src, int32_t* restrict dst, size_t length, const struct ConvolutionKernel* restrict kernel); void Convolve1DPad0PackedS32(const int32_t* restrict src, int32_t* restrict dst, size_t length, const struct ConvolutionKernel* restrict kernel);
void Convolve2DClampPacked8(const uint8_t* restrict src, uint8_t* restrict dst, size_t width, size_t height, size_t stride, const struct ConvolutionKernel* restrict kernel); void Convolve2DClampPacked8(const uint8_t* restrict src, uint8_t* restrict dst, size_t width, size_t height, size_t stride, const struct ConvolutionKernel* restrict kernel);
void Convolve2DClampChannels8(const uint8_t* restrict src, uint8_t* restrict dst, size_t width, size_t height, size_t stride, size_t channels, const struct ConvolutionKernel* restrict kernel);
CXX_GUARD_END CXX_GUARD_END

View File

@ -136,3 +136,47 @@ void Convolve2DClampPacked8(const uint8_t* restrict src, uint8_t* restrict dst,
} }
} }
} }
void Convolve2DClampChannels8(const uint8_t* restrict src, uint8_t* restrict dst, size_t width, size_t height, size_t stride, size_t channels, const struct ConvolutionKernel* restrict kernel) {
if (kernel->rank != 2) {
return;
}
size_t kx2 = kernel->dims[0] / 2;
size_t ky2 = kernel->dims[1] / 2;
size_t y;
for (y = 0; y < height; ++y) {
uint8_t* orow = &dst[y * stride];
size_t x;
for (x = 0; x < width; ++x) {
size_t c;
for (c = 0; c < channels; ++c) {
float sum = 0.f;
size_t ky;
for (ky = 0; ky < kernel->dims[1]; ++ky) {
size_t cy = 0;
if (y + ky > ky2) {
cy = y + ky - ky2;
}
if (cy >= height) {
cy = height - 1;
}
const uint8_t* irow = &src[cy * stride];
size_t kx;
for (kx = 0; kx < kernel->dims[0]; ++kx) {
size_t cx = 0;
if (x + kx > kx2) {
cx = x + kx - kx2;
}
if (cx >= width) {
cx = width - 1;
}
cx *= channels;
sum += irow[cx + c] * kernel->kernel[ky * kernel->dims[0] + kx];
}
}
*orow = sum;
++orow;
}
}
}
}