mirror of
https://github.com/RPCS3/discord-bot.git
synced 2026-01-31 01:25:22 +01:00
implement pluggable ocr provider and add tesseract and florence2 impl
This commit is contained in:
@@ -5,6 +5,7 @@ using CompatBot.Database;
|
||||
using CompatBot.Database.Providers;
|
||||
using CompatBot.EventHandlers;
|
||||
using CompatBot.EventHandlers.LogParsing.SourceHandlers;
|
||||
using CompatBot.Ocr;
|
||||
using Microsoft.EntityFrameworkCore;
|
||||
|
||||
namespace CompatBot.Commands;
|
||||
@@ -20,11 +21,8 @@ internal static class BotStatus
|
||||
var latency = ctx.Client.GetConnectionLatency(Config.BotGuildId);
|
||||
var embed = new DiscordEmbedBuilder { Color = DiscordColor.Purple }
|
||||
.AddField("Current Uptime", Config.Uptime.Elapsed.AsShortTimespan(), true)
|
||||
.AddField("Discord Latency", $"{latency.TotalMilliseconds:0.0} ms", true);
|
||||
if (Config.AzureComputerVisionKey is {Length: >0})
|
||||
embed.AddField("Max OCR Queue", MediaScreenshotMonitor.MaxQueueLength.ToString(), true);
|
||||
else
|
||||
embed.AddField("Max OCR Queue", "-", true);
|
||||
.AddField("Discord Latency", $"{latency.TotalMilliseconds:0.0} ms", true)
|
||||
.AddField("Max OCR Queue", $"{OcrProvider.BackendName} / {MediaScreenshotMonitor.MaxQueueLength}", true);
|
||||
var osInfo = RuntimeInformation.OSDescription;
|
||||
if (Environment.OSVersion.Platform is PlatformID.Unix or PlatformID.MacOSX)
|
||||
osInfo = RuntimeInformation.RuntimeIdentifier;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
using System.IO;
|
||||
using System.Net;
|
||||
using System.Net.Http;
|
||||
using System.Net.Mime;
|
||||
using CompatBot.EventHandlers;
|
||||
using CompatBot.Utils.Extensions;
|
||||
using DSharpPlus.Commands.Processors.TextCommands;
|
||||
@@ -354,20 +355,20 @@ internal sealed class Vision
|
||||
{
|
||||
foreach (var embed in msg.Embeds)
|
||||
{
|
||||
if (embed.Image?.Url?.ToString() is string url)
|
||||
if (embed.Image?.Url?.ToString() is {Length: >0} url)
|
||||
yield return url;
|
||||
else if (embed.Thumbnail?.Url?.ToString() is string thumbUrl)
|
||||
else if (embed.Thumbnail?.Url?.ToString() is {Length: >0} thumbUrl)
|
||||
yield return thumbUrl;
|
||||
}
|
||||
}
|
||||
|
||||
internal static IEnumerable<DiscordAttachment> GetImageAttachments(DiscordMessage message)
|
||||
internal static IEnumerable<string> GetImageAttachments(DiscordMessage message)
|
||||
=> message.Attachments.Where(a =>
|
||||
a.FileName.EndsWith(".jpg", StringComparison.InvariantCultureIgnoreCase)
|
||||
|| a.FileName.EndsWith(".png", StringComparison.InvariantCultureIgnoreCase)
|
||||
|| a.FileName.EndsWith(".jpeg", StringComparison.InvariantCultureIgnoreCase)
|
||||
//|| a.FileName.EndsWith(".webp", StringComparison.InvariantCultureIgnoreCase)
|
||||
);
|
||||
a.MediaType is MediaTypeNames.Image.Jpeg
|
||||
or MediaTypeNames.Image.Png
|
||||
or MediaTypeNames.Image.Webp
|
||||
&& a.Url is {Length: >0}
|
||||
).Select(att => att.Url!);
|
||||
|
||||
private static string GetDescription(ImageDescriptionDetails description, AdultInfo adultInfo)
|
||||
{
|
||||
@@ -416,8 +417,8 @@ internal sealed class Vision
|
||||
return null;
|
||||
|
||||
var reactMsg = tctx.Message;
|
||||
if (GetImageAttachments(reactMsg).FirstOrDefault() is DiscordAttachment attachment)
|
||||
imageUrl = attachment.Url;
|
||||
if (GetImageAttachments(reactMsg).FirstOrDefault() is {} attUrl)
|
||||
imageUrl = attUrl;
|
||||
imageUrl = imageUrl?.Trim() ?? "";
|
||||
if (!string.IsNullOrEmpty(imageUrl)
|
||||
&& imageUrl.StartsWith('<')
|
||||
@@ -431,17 +432,16 @@ internal sealed class Vision
|
||||
|| str.StartsWith("last")
|
||||
|| str.StartsWith("previous")
|
||||
|| str.StartsWith("^"))
|
||||
&& ctx.Channel.PermissionsFor(
|
||||
await ctx.Client.GetMemberAsync(ctx.Guild, ctx.Client.CurrentUser).ConfigureAwait(false)
|
||||
).HasPermission(DiscordPermission.ReadMessageHistory))
|
||||
&& await ctx.Client.GetMemberAsync(ctx.Guild, ctx.Client.CurrentUser).ConfigureAwait(false) is {} member
|
||||
&& ctx.Channel.PermissionsFor(member).HasPermission(DiscordPermission.ReadMessageHistory))
|
||||
try
|
||||
{
|
||||
var previousMessages = (await ctx.Channel.GetMessagesBeforeCachedAsync(tctx.Message.Id, 10).ConfigureAwait(false))!;
|
||||
imageUrl = (
|
||||
from m in previousMessages
|
||||
where m.Attachments?.Count > 0
|
||||
from a in GetImageAttachments(m)
|
||||
select a.Url
|
||||
from url in GetImageAttachments(m)
|
||||
select url
|
||||
).FirstOrDefault();
|
||||
if (string.IsNullOrEmpty(imageUrl))
|
||||
{
|
||||
|
||||
@@ -46,6 +46,7 @@
|
||||
<PackageReference Include="DSharpPlus.Commands" Version="5.0.0-nightly-02520" />
|
||||
<PackageReference Include="DSharpPlus.Interactivity" Version="5.0.0-nightly-02520" />
|
||||
<PackageReference Include="DSharpPlus.Natives.Zstd" Version="1.5.7.21" />
|
||||
<PackageReference Include="Florence2" Version="24.11.53800" />
|
||||
<PackageReference Include="Google.Apis.Drive.v3" Version="1.69.0.3783" />
|
||||
<PackageReference Include="MathParser.org-mXparser" Version="6.1.0" />
|
||||
<PackageReference Include="MegaApiClient" Version="1.10.4" />
|
||||
@@ -72,6 +73,7 @@
|
||||
<PackageReference Include="SharpCompress" Version="0.40.0" />
|
||||
<PackageReference Include="SixLabors.ImageSharp.Drawing" Version="2.1.6" />
|
||||
<PackageReference Include="System.Linq.Async" Version="6.0.3" />
|
||||
<PackageReference Include="TesseractCSharp" Version="1.0.5" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\Clients\CirrusCiClient\CirrusCiClient.csproj" />
|
||||
|
||||
@@ -89,6 +89,7 @@ internal static class Config
|
||||
public static string IrdCachePath => config.GetValue(nameof(IrdCachePath), "./ird/");
|
||||
public static string RedumpDatfileCachePath => config.GetValue(nameof(RedumpDatfileCachePath), "./datfile/");
|
||||
public static string RenameNameSuffix => config.GetValue(nameof(RenameNameSuffix), " (Rule 7)");
|
||||
public static string OcrBackend => config.GetValue(nameof(OcrBackend), "auto"); // possible values: auto, tesseract, florence2, azure
|
||||
|
||||
public static double GameTitleMatchThreshold => config.GetValue(nameof(GameTitleMatchThreshold), 0.57);
|
||||
public static byte[] CryptoSalt => Convert.FromBase64String(config.GetValue(nameof(CryptoSalt), ""));
|
||||
|
||||
@@ -24,6 +24,8 @@ internal static class ContentFilterMonitor
|
||||
if (message?.Author is null)
|
||||
message = await e.Channel.GetMessageAsync(e.Message.Id).ConfigureAwait(false);
|
||||
}
|
||||
if (message.Attachments.Any())
|
||||
MediaScreenshotMonitor.EnqueueOcrTask(message);
|
||||
await ContentFilter.IsClean(c, message).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
@@ -3,56 +3,54 @@ using CompatApiClient.Utils;
|
||||
using CompatBot.Commands;
|
||||
using CompatBot.Database;
|
||||
using CompatBot.Database.Providers;
|
||||
using CompatBot.Ocr;
|
||||
using CompatBot.Utils.Extensions;
|
||||
using Microsoft.Azure.CognitiveServices.Vision.ComputerVision;
|
||||
using Microsoft.Azure.CognitiveServices.Vision.ComputerVision.Models;
|
||||
|
||||
namespace CompatBot.EventHandlers;
|
||||
|
||||
internal sealed class MediaScreenshotMonitor
|
||||
{
|
||||
private readonly ComputerVisionClient cvClient = new(new ApiKeyServiceClientCredentials(Config.AzureComputerVisionKey)) {Endpoint = Config.AzureComputerVisionEndpoint};
|
||||
private readonly SemaphoreSlim workSemaphore = new(0);
|
||||
private readonly ConcurrentQueue<(MessageCreatedEventArgs evt, Guid readOperationId)> workQueue = new ConcurrentQueue<(MessageCreatedEventArgs args, Guid readOperationId)>();
|
||||
private static readonly SemaphoreSlim WorkSemaphore = new(0);
|
||||
private static readonly ConcurrentQueue<(DiscordMessage msg, string imgUrl)> WorkQueue = new();
|
||||
public DiscordClient Client { get; internal set; } = null!;
|
||||
public static int MaxQueueLength { get; private set; }
|
||||
|
||||
public async Task OnMessageCreated(DiscordClient client, MessageCreatedEventArgs evt)
|
||||
{
|
||||
if (string.IsNullOrEmpty(Config.AzureComputerVisionKey))
|
||||
if (!OcrProvider.IsAvailable)
|
||||
return;
|
||||
|
||||
var message = evt.Message;
|
||||
if (message == null)
|
||||
if (evt.Message is not {} message)
|
||||
return;
|
||||
|
||||
if (!Config.Moderation.OcrChannels.Contains(evt.Channel.Id))
|
||||
return;
|
||||
// if (!Config.Moderation.OcrChannels.Contains(evt.Channel.Id))
|
||||
// return;
|
||||
|
||||
#if !DEBUG
|
||||
if (message.Author.IsBotSafeCheck())
|
||||
return;
|
||||
|
||||
#if !DEBUG
|
||||
if (await message.Author.IsSmartlistedAsync(client).ConfigureAwait(false))
|
||||
return;
|
||||
#endif
|
||||
|
||||
EnqueueOcrTask(evt.Message);
|
||||
}
|
||||
|
||||
public static void EnqueueOcrTask(DiscordMessage message)
|
||||
{
|
||||
if (!message.Attachments.Any())
|
||||
return;
|
||||
|
||||
var images = Vision.GetImageAttachments(message).Select(att => att.Url)
|
||||
var images = Vision.GetImageAttachments(message)
|
||||
.Concat(Vision.GetImagesFromEmbeds(message))
|
||||
.ToList();
|
||||
var tasks = new List<Task<ReadHeaders>>(images.Count);
|
||||
foreach (var url in images)
|
||||
tasks.Add(cvClient.ReadAsync(url, cancellationToken: Config.Cts.Token));
|
||||
foreach (var t in tasks)
|
||||
{
|
||||
try
|
||||
{
|
||||
var headers = await t.ConfigureAwait(false);
|
||||
workQueue.Enqueue((evt, new(new Uri(headers.OperationLocation).Segments.Last())));
|
||||
workSemaphore.Release();
|
||||
WorkQueue.Enqueue((message, url));
|
||||
WorkSemaphore.Release();
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
@@ -63,84 +61,66 @@ internal sealed class MediaScreenshotMonitor
|
||||
|
||||
public async Task ProcessWorkQueue()
|
||||
{
|
||||
if (string.IsNullOrEmpty(Config.AzureComputerVisionKey))
|
||||
if (!OcrProvider.IsAvailable)
|
||||
return;
|
||||
|
||||
Guid? reEnqueueId = null;
|
||||
do
|
||||
{
|
||||
await workSemaphore.WaitAsync(Config.Cts.Token).ConfigureAwait(false);
|
||||
await WorkSemaphore.WaitAsync(Config.Cts.Token).ConfigureAwait(false);
|
||||
if (Config.Cts.IsCancellationRequested)
|
||||
return;
|
||||
|
||||
MaxQueueLength = Math.Max(MaxQueueLength, workQueue.Count);
|
||||
if (!workQueue.TryDequeue(out var item))
|
||||
MaxQueueLength = Math.Max(MaxQueueLength, WorkQueue.Count);
|
||||
if (!WorkQueue.TryDequeue(out var item))
|
||||
continue;
|
||||
|
||||
if (item.readOperationId == reEnqueueId)
|
||||
{
|
||||
await Task.Delay(100).ConfigureAwait(false);
|
||||
reEnqueueId = null;
|
||||
if (Config.Cts.IsCancellationRequested)
|
||||
return;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
var result = await cvClient.GetReadResultAsync(item.readOperationId, Config.Cts.Token).ConfigureAwait(false);
|
||||
if (result.Status == OperationStatusCodes.Succeeded)
|
||||
if (await OcrProvider.GetTextAsync(item.imgUrl, Config.Cts.Token).ConfigureAwait(false) is {Length: >0} result
|
||||
&& !Config.Cts.Token.IsCancellationRequested)
|
||||
{
|
||||
if (result.AnalyzeResult?.ReadResults?.SelectMany(r => r.Lines).Any() ?? false)
|
||||
var cnt = true;
|
||||
var prefix = $"[{item.msg.Id % 100:00}]";
|
||||
var ocrTextBuf = new StringBuilder($"OCR result of message <{item.msg.JumpLink}>:").AppendLine();
|
||||
Config.Log.Debug($"{prefix} OCR result of message {item.msg.JumpLink}:");
|
||||
var duplicates = new HashSet<string>();
|
||||
ocrTextBuf.AppendLine(result.Sanitize());
|
||||
Config.Log.Debug($"{prefix} {result}");
|
||||
if (cnt
|
||||
&& await ContentFilter.FindTriggerAsync(FilterContext.Chat, result).ConfigureAwait(false) is Piracystring hit
|
||||
&& duplicates.Add(hit.String))
|
||||
{
|
||||
var cnt = true;
|
||||
var prefix = $"[{item.evt.Message.Id % 100:00}]";
|
||||
var ocrTextBuf = new StringBuilder($"OCR result of message <{item.evt.Message.JumpLink}>:").AppendLine();
|
||||
Config.Log.Debug($"{prefix} OCR result of message {item.evt.Message.JumpLink}:");
|
||||
var duplicates = new HashSet<string>();
|
||||
foreach (var r in result.AnalyzeResult.ReadResults)
|
||||
foreach (var l in r.Lines)
|
||||
FilterAction suppressFlags = 0;
|
||||
if ("media".Equals(item.msg.Channel?.Name))
|
||||
{
|
||||
ocrTextBuf.AppendLine(l.Text.Sanitize());
|
||||
Config.Log.Debug($"{prefix} {l.Text}");
|
||||
if (cnt
|
||||
&& await ContentFilter.FindTriggerAsync(FilterContext.Chat, l.Text).ConfigureAwait(false) is Piracystring hit
|
||||
&& duplicates.Add(hit.String))
|
||||
{
|
||||
FilterAction suppressFlags = 0;
|
||||
if ("media".Equals(item.evt.Channel.Name))
|
||||
suppressFlags = FilterAction.SendMessage | FilterAction.ShowExplain;
|
||||
await ContentFilter.PerformFilterActions(
|
||||
Client,
|
||||
item.evt.Message,
|
||||
hit,
|
||||
suppressFlags,
|
||||
l.Text,
|
||||
"🖼 Screenshot of an undesirable content",
|
||||
"Screenshot of an undesirable content"
|
||||
).ConfigureAwait(false);
|
||||
cnt &= !hit.Actions.HasFlag(FilterAction.RemoveContent) && !hit.Actions.HasFlag(FilterAction.IssueWarning);
|
||||
}
|
||||
suppressFlags = FilterAction.SendMessage | FilterAction.ShowExplain;
|
||||
}
|
||||
await ContentFilter.PerformFilterActions(
|
||||
Client,
|
||||
item.msg,
|
||||
hit,
|
||||
suppressFlags,
|
||||
result,
|
||||
"🖼 Screenshot of an undesirable content",
|
||||
"Screenshot of an undesirable content"
|
||||
).ConfigureAwait(false);
|
||||
cnt &= !hit.Actions.HasFlag(FilterAction.RemoveContent) && !hit.Actions.HasFlag(FilterAction.IssueWarning);
|
||||
}
|
||||
var ocrText = ocrTextBuf.ToString();
|
||||
var hasVkDiagInfo = ocrText.Contains("Vulkan Diagnostics Tool v")
|
||||
|| ocrText.Contains("VkDiag Version:");
|
||||
if (!cnt || hasVkDiagInfo)
|
||||
{
|
||||
try
|
||||
{
|
||||
var botSpamCh = await Client.GetChannelAsync(Config.ThumbnailSpamId).ConfigureAwait(false);
|
||||
await botSpamCh.SendAutosplitMessageAsync(ocrTextBuf, blockStart: "", blockEnd: "").ConfigureAwait(false);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
Config.Log.Warn(ex);
|
||||
}
|
||||
var ocrText = ocrTextBuf.ToString();
|
||||
var hasVkDiagInfo = ocrText.Contains("Vulkan Diagnostics Tool v")
|
||||
|| ocrText.Contains("VkDiag Version:");
|
||||
if (!cnt || hasVkDiagInfo)
|
||||
try
|
||||
{
|
||||
var botSpamCh = await Client.GetChannelAsync(Config.ThumbnailSpamId).ConfigureAwait(false);
|
||||
await botSpamCh.SendAutosplitMessageAsync(ocrTextBuf, blockStart: "", blockEnd: "").ConfigureAwait(false);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
Config.Log.Warn(ex);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (result.Status is OperationStatusCodes.NotStarted or OperationStatusCodes.Running)
|
||||
{
|
||||
workQueue.Enqueue(item);
|
||||
reEnqueueId ??= item.readOperationId;
|
||||
workSemaphore.Release();
|
||||
}
|
||||
}
|
||||
catch (Exception e)
|
||||
|
||||
52
CompatBot/Ocr/Backend/AzureVision.cs
Normal file
52
CompatBot/Ocr/Backend/AzureVision.cs
Normal file
@@ -0,0 +1,52 @@
|
||||
using CompatApiClient.Utils;
|
||||
using Microsoft.Azure.CognitiveServices.Vision.ComputerVision;
|
||||
using Microsoft.Azure.CognitiveServices.Vision.ComputerVision.Models;
|
||||
|
||||
namespace CompatBot.Ocr.Backend;
|
||||
|
||||
/// <summary>
/// OCR backend that sends the image URL to the Azure Computer Vision Read API
/// and polls the resulting read operation until it finishes.
/// </summary>
public class AzureVision: IOcrBackend
{
    // Pause between two polls of a read operation that is still pending.
    private static readonly TimeSpan PollInterval = TimeSpan.FromSeconds(1);

    // Created in InitializeAsync; stays null when no API key is configured.
    private ComputerVisionClient? cvClient;

    public string Name => "azure";

    /// <summary>
    /// Creates the Computer Vision client from the configured key and endpoint.
    /// </summary>
    /// <returns><c>false</c> when no API key is configured, <c>true</c> otherwise.</returns>
    public Task<bool> InitializeAsync(CancellationToken cancellationToken)
    {
        if (Config.AzureComputerVisionKey is not { Length: > 0 })
            return Task.FromResult(false);

        cvClient = new(new ApiKeyServiceClientCredentials(Config.AzureComputerVisionKey))
        {
            Endpoint = Config.AzureComputerVisionEndpoint
        };
        return Task.FromResult(true);
    }

    /// <summary>
    /// Submits <paramref name="imgUrl"/> for recognition and returns every recognized line,
    /// one per text line of the result.
    /// </summary>
    /// <returns>The recognized text, or an empty string when the operation did not succeed or found no text.</returns>
    /// <exception cref="InvalidOperationException">The backend was not successfully initialized.</exception>
    public async Task<string> GetTextAsync(string imgUrl, CancellationToken cancellationToken)
    {
        if (cvClient is not {} client)
            throw new InvalidOperationException("Azure Computer Vision backend was not initialized");

        var headers = await client.ReadAsync(imgUrl, cancellationToken: cancellationToken).ConfigureAwait(false);
        // the operation id is the last path segment of the returned operation location
        var operationId = new Guid(new Uri(headers.OperationLocation).Segments.Last());

        ReadOperationResult result;
        while (true)
        {
            // honour the caller's token for polling too, not only for the initial request and the delay
            result = await client.GetReadResultAsync(operationId, cancellationToken).ConfigureAwait(false);
            if (result.Status is not (OperationStatusCodes.NotStarted or OperationStatusCodes.Running))
                break;

            await Task.Delay(PollInterval, cancellationToken).ConfigureAwait(false);
        }

        if (result.Status is not OperationStatusCodes.Succeeded)
        {
            Config.Log.Warn($"Failed to OCR image {imgUrl}: {result.Status}");
            return "";
        }

        // a successful operation without any lines is simply an image without text, not a failure
        var ocrTextBuf = new StringBuilder();
        foreach (var line in result.AnalyzeResult?.ReadResults?.SelectMany(r => r.Lines) ?? [])
            ocrTextBuf.AppendLine(line.Text);
        return ocrTextBuf.ToString();
    }
}
|
||||
38
CompatBot/Ocr/Backend/BackendBase.cs
Normal file
38
CompatBot/Ocr/Backend/BackendBase.cs
Normal file
@@ -0,0 +1,38 @@
|
||||
using System.IO;
|
||||
using System.Net.Http;
|
||||
using CompatApiClient.Compression;
|
||||
|
||||
namespace CompatBot.Ocr.Backend;
|
||||
|
||||
/// <summary>
/// Common base for OCR backends that run locally: provides a shared HTTP client
/// for downloading images and models, and a per-backend model cache folder.
/// </summary>
public abstract class BackendBase: IOcrBackend, IDisposable
{
    // One client shared by every backend instance for the lifetime of the process.
    protected static readonly HttpClient HttpClient = HttpClientFactory.Create(new CompressionMessageHandler());

    public abstract string Name { get; }

    /// <summary>
    /// Makes sure the model cache folder for this backend exists.
    /// </summary>
    /// <returns><c>false</c> when the folder could not be created, <c>true</c> otherwise.</returns>
    public virtual Task<bool> InitializeAsync(CancellationToken cancellationToken)
    {
        try
        {
            // CreateDirectory does nothing when the folder is already there, so no Exists check is needed
            Directory.CreateDirectory(ModelCachePath);
        }
        catch (Exception e)
        {
            Config.Log.Error(e, $"Failed to create model cache folder '{ModelCachePath}'");
            return Task.FromResult(false);
        }
        return Task.FromResult(true);
    }

    public abstract Task<string> GetTextAsync(string imgUrl, CancellationToken cancellationToken);

    /// <summary>
    /// Releases resources owned by this instance. The shared <see cref="HttpClient"/> is static
    /// and deliberately left alone: disposing it here would break every other backend instance.
    /// </summary>
    public virtual void Dispose() => GC.SuppressFinalize(this);

    /// <summary>
    /// Folder under the local application data directory where this backend keeps its model files.
    /// </summary>
    protected string ModelCachePath => Path.Combine(
        Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData),
        "discord-bot",
        "ocr-models",
        Name.ToLowerInvariant()
    );
}
|
||||
65
CompatBot/Ocr/Backend/Florence2.cs
Normal file
65
CompatBot/Ocr/Backend/Florence2.cs
Normal file
@@ -0,0 +1,65 @@
|
||||
using Florence2;
|
||||
|
||||
namespace CompatBot.Ocr.Backend;
|
||||
|
||||
/// <summary>
/// OCR backend that runs the Florence2 model locally. Model files are downloaded
/// into the backend's model cache folder during initialization.
/// </summary>
public class Florence2: BackendBase
{
    // Created in InitializeAsync; stays null when download or model creation failed.
    private Florence2Model? model;

    public override string Name => "florence2";

    /// <summary>
    /// Prepares the cache folder, downloads the model files and creates the model.
    /// </summary>
    /// <returns><c>false</c> when any of these steps failed, <c>true</c> otherwise.</returns>
    public override async Task<bool> InitializeAsync(CancellationToken cancellationToken)
    {
        if (!await base.InitializeAsync(cancellationToken).ConfigureAwait(false))
            return false;

        var modelSource = new FlorenceModelDownloader(ModelCachePath);
        try
        {
            // the downloader reports problems through the status callback, so remember whether any were seen
            var errors = false;
            await modelSource.DownloadModelsAsync(s =>
                {
                    if (s.Error is { Length: > 0 } errorMsg)
                    {
                        Config.Log.Error($"Failed to download Florence2 model files: {errorMsg}");
                        errors = true;
                    }
                    else if (s.Message is { Length: > 0 } msg)
                    {
                        Config.Log.Info($"Florence2 model download message: {msg}");
                    }
                },
                Config.LoggerFactory.CreateLogger("florence2"),
                cancellationToken
            ).ConfigureAwait(false);
            if (errors)
                return false;
        }
        catch (Exception e)
        {
            Config.Log.Error(e, "Failed to download Florence2 model files");
            return false;
        }

        try
        {
            model = new(modelSource);
        }
        catch (Exception e)
        {
            Config.Log.Error(e, "Failed to initialize Florence2 model");
            return false;
        }
        return true;
    }

    /// <summary>
    /// Downloads the image at <paramref name="imgUrl"/> and returns the text of every detected
    /// region, one region per line, without trailing whitespace.
    /// </summary>
    /// <exception cref="InvalidOperationException">The backend was not successfully initialized.</exception>
    public override async Task<string> GetTextAsync(string imgUrl, CancellationToken cancellationToken)
    {
        if (model is not {} ocrModel)
            throw new InvalidOperationException("Florence2 model was not initialized");

        await using var imgStream = await HttpClient.GetStreamAsync(imgUrl, cancellationToken).ConfigureAwait(false);
        // pass the caller's token so a shutdown can interrupt the inference as well as the download
        var results = ocrModel.Run(TaskTypes.OCR_WITH_REGION, [imgStream], "", cancellationToken);
        if (results[0].OCRBBox is not {} boxes)
            return "";

        var result = new StringBuilder();
        foreach (var box in boxes)
            result.AppendLine(box.Text);
        return result.ToString().TrimEnd();
    }
}
|
||||
8
CompatBot/Ocr/Backend/IOcrBackend.cs
Normal file
8
CompatBot/Ocr/Backend/IOcrBackend.cs
Normal file
@@ -0,0 +1,8 @@
|
||||
namespace CompatBot.Ocr.Backend;
|
||||
|
||||
/// <summary>
/// Contract implemented by every OCR engine the bot can use.
/// </summary>
public interface IOcrBackend
{
    /// <summary>Short identifier of the backend.</summary>
    string Name { get; }

    /// <summary>
    /// Prepares the backend for use.
    /// </summary>
    /// <param name="cancellationToken">Token used to abort the initialization.</param>
    /// <returns><c>true</c> when the backend is ready, <c>false</c> when it cannot be used.</returns>
    Task<bool> InitializeAsync(CancellationToken cancellationToken);

    /// <summary>
    /// Recognizes text in the image located at <paramref name="imgUrl"/>.
    /// </summary>
    /// <param name="imgUrl">URL of the image to process.</param>
    /// <param name="cancellationToken">Token used to abort the operation.</param>
    /// <returns>The recognized text.</returns>
    Task<string> GetTextAsync(string imgUrl, CancellationToken cancellationToken);
}
|
||||
81
CompatBot/Ocr/Backend/Tesseract.cs
Normal file
81
CompatBot/Ocr/Backend/Tesseract.cs
Normal file
@@ -0,0 +1,81 @@
|
||||
using System.IO;
|
||||
using System.Net.Http;
|
||||
using CompatApiClient.Compression;
|
||||
using TesseractCSharp;
|
||||
using TesseractCSharp.Interop;
|
||||
|
||||
namespace CompatBot.Ocr.Backend;
|
||||
|
||||
/// <summary>
/// OCR backend built on the native Tesseract engine with the English model,
/// which is downloaded into the model cache folder on first use.
/// </summary>
internal class Tesseract: BackendBase, IDisposable
{
    private const string ModelFileName = "eng.traineddata";
    // existing repos: tessdata_fast, tessdata, tessdata_best
    private const string ModelUri = "https://github.com/tesseract-ocr/tessdata_best/raw/refs/heads/main/eng.traineddata";

    // Created in InitializeAsync; stays null when initialization failed or never ran.
    private TesseractEngine? engine;

    public override string Name => "tesseract";

    /// <summary>
    /// Loads the native libraries, makes sure the English model is cached locally and creates the engine.
    /// </summary>
    /// <returns><c>false</c> when any of these steps failed, <c>true</c> otherwise.</returns>
    public override async Task<bool> InitializeAsync(CancellationToken cancellationToken)
    {
        if (!await base.InitializeAsync(cancellationToken).ConfigureAwait(false))
            return false;

        try
        {
            NativeConstants.InitNativeLoader();
        }
        catch (Exception e)
        {
            Config.Log.Error(e, "Failed to load Tesseract native dependencies");
            return false;
        }

        var engModelPath = Path.Combine(ModelCachePath, ModelFileName);
        if (!File.Exists(engModelPath)
            && !await DownloadModelAsync(engModelPath, cancellationToken).ConfigureAwait(false))
            return false;

        try
        {
            engine = new(ModelCachePath, "eng", EngineMode.Default);
        }
        catch (Exception e)
        {
            Config.Log.Error(e, "Failed to initialize Tesseract engine");
            return false;
        }

        return true;
    }

    /// <summary>
    /// Downloads the image at <paramref name="imgUrl"/> and returns the text recognized in it.
    /// </summary>
    /// <exception cref="InvalidOperationException">The backend was not successfully initialized.</exception>
    public override async Task<string> GetTextAsync(string imgUrl, CancellationToken cancellationToken)
    {
        if (engine is not {} ocrEngine)
            throw new InvalidOperationException("Tesseract engine was not initialized");

        var imgData = await HttpClient.GetByteArrayAsync(imgUrl, cancellationToken).ConfigureAwait(false);
        using var img = Pix.LoadFromMemory(imgData);
        using var page = ocrEngine.Process(img);
        return page.GetText() ?? "";
    }

    public override void Dispose()
    {
        // the engine does not exist when initialization failed or was never attempted
        engine?.Dispose();
        base.Dispose();
    }

    // Fetches the model into a temporary file and moves it into place only after a complete download,
    // so an interrupted transfer can never leave a truncated model that later runs would accept.
    private static async Task<bool> DownloadModelAsync(string targetPath, CancellationToken cancellationToken)
    {
        var tmpPath = targetPath + ".tmp";
        try
        {
            await using (var response = await HttpClient.GetStreamAsync(ModelUri, cancellationToken).ConfigureAwait(false))
            await using (var file = File.Open(tmpPath, new FileStreamOptions
                         {
                             Mode = FileMode.Create,
                             Access = FileAccess.Write,
                             Share = FileShare.None,
                             Options = FileOptions.Asynchronous | FileOptions.SequentialScan,
                         }))
            {
                await response.CopyToAsync(file, cancellationToken).ConfigureAwait(false);
            }
            File.Move(tmpPath, targetPath, overwrite: true);
            return true;
        }
        catch (Exception e)
        {
            Config.Log.Error(e, "Failed to download model data");
            try
            {
                // best-effort cleanup; Delete is a no-op when the file does not exist
                File.Delete(tmpPath);
            }
            catch (Exception cleanupError)
            {
                Config.Log.Warn(cleanupError, $"Failed to remove partial model download '{tmpPath}'");
            }
            return false;
        }
    }
}
|
||||
59
CompatBot/Ocr/OcrProvider.cs
Normal file
59
CompatBot/Ocr/OcrProvider.cs
Normal file
@@ -0,0 +1,59 @@
|
||||
using System.Runtime.InteropServices;
|
||||
using CompatBot.Ocr.Backend;
|
||||
|
||||
namespace CompatBot.Ocr;
|
||||
|
||||
/// <summary>
/// Entry point for OCR: selects and initializes one backend at startup and
/// forwards text recognition requests to it.
/// </summary>
public static class OcrProvider
{
    // Config value that requests automatic backend selection.
    private const string AutoBackendName = "auto";
    // Florence2 is picked automatically only when more than this much memory is available.
    private const long Florence2MemoryThreshold = 4L * 1024 * 1024 * 1024;

    // Set only after a backend has initialized successfully.
    private static IOcrBackend? backend;

    public static bool IsAvailable => backend is not null;
    public static string BackendName => backend?.Name ?? "not configured";

    /// <summary>
    /// Creates the backend named by <c>Config.OcrBackend</c>, or selects one automatically when
    /// that value does not name a known backend, and initializes it. On failure OCR stays unavailable.
    /// </summary>
    public static async Task InitializeAsync(CancellationToken cancellationToken)
    {
        var backendName = Config.OcrBackend;
        if (GetBackend(backendName) is not {} result)
        {
            // anything other than the explicit "auto" value is most likely a typo worth reporting
            if (!AutoBackendName.Equals(backendName, StringComparison.OrdinalIgnoreCase))
                Config.Log.Warn($"Unknown OCR backend '{backendName}', falling back to automatic selection");
            backendName = SelectBackendName();
            result = GetBackend(backendName)!;
        }

        bool initialized;
        try
        {
            initialized = await result.InitializeAsync(cancellationToken).ConfigureAwait(false);
        }
        catch (Exception e) when (e is not OperationCanceledException)
        {
            // OCR is optional, so a misbehaving backend must not take the rest of the startup down
            Config.Log.Error(e, $"Failed to initialize OCR backend {result.Name}");
            initialized = false;
        }

        if (initialized)
        {
            backend = result;
            Config.Log.Info($"Initialized OCR backend {BackendName}");
        }
        else
            Config.Log.Warn($"OCR backend {result.Name} is not available, OCR is disabled");
    }

    /// <summary>
    /// Recognizes text in the image at <paramref name="imageUrl"/> using the active backend.
    /// </summary>
    /// <returns>The recognized text, or an empty string when no backend is available or recognition failed.</returns>
    public static async Task<string> GetTextAsync(string imageUrl, CancellationToken cancellationToken)
    {
        if (backend is null)
            return "";

        try
        {
            return await backend.GetTextAsync(imageUrl, cancellationToken).ConfigureAwait(false);
        }
        catch (Exception e)
        {
            Config.Log.Warn(e, $"Failed to OCR image {imageUrl}");
            return "";
        }
    }

    // Automatic choice: Azure when a key is configured; otherwise Florence2 when there is enough memory
    // or the OS architecture is not x86/x64; otherwise Tesseract.
    private static string SelectBackendName()
    {
        if (Config.AzureComputerVisionKey is { Length: > 0 })
            return "azure";

        if (GC.GetGCMemoryInfo().TotalAvailableMemoryBytes > Florence2MemoryThreshold
            || RuntimeInformation.OSArchitecture is not (Architecture.X64 or Architecture.X86))
            return "florence2";

        return "tesseract";
    }

    // Maps a backend name (case-insensitive) to a new, uninitialized backend instance; null for unknown names.
    private static IOcrBackend? GetBackend(string name)
        => name.ToLowerInvariant() switch
        {
            "tesseract" => new Backend.Tesseract(),
            "florence2" => new Backend.Florence2(),
            "azure" => new Backend.AzureVision(),
            _ => null,
        };
}
|
||||
@@ -9,6 +9,7 @@ using CompatBot.Commands.Processors;
|
||||
using CompatBot.Database;
|
||||
using CompatBot.Database.Providers;
|
||||
using CompatBot.EventHandlers;
|
||||
using CompatBot.Ocr;
|
||||
using CompatBot.Utils.Extensions;
|
||||
using DSharpPlus.Commands.Processors.TextCommands;
|
||||
using DSharpPlus.Commands.Processors.TextCommands.Parsing;
|
||||
@@ -119,7 +120,8 @@ internal static class Program
|
||||
Config.GetAzureDevOpsClient().GetPipelineDurationAsync(Config.Cts.Token),
|
||||
new GithubClient.Client(Config.GithubToken).GetPipelineDurationAsync(Config.Cts.Token),
|
||||
Config.GetCurrentGitRevisionAsync(Config.Cts.Token),
|
||||
Bot.UpdateCheckScheduledAsync(Config.Cts.Token)
|
||||
Bot.UpdateCheckScheduledAsync(Config.Cts.Token),
|
||||
OcrProvider.InitializeAsync(Config.Cts.Token)
|
||||
);
|
||||
|
||||
try
|
||||
|
||||
@@ -2,11 +2,10 @@ FROM mcr.microsoft.com/dotnet/sdk:9.0-noble AS base
|
||||
|
||||
# Native libgdiplus dependencies
|
||||
RUN apt-get update
|
||||
RUN DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get install -y --allow-unauthenticated libc6-dev libgdiplus libx11-dev fonts-roboto tzdata
|
||||
|
||||
# debian-specific?
|
||||
#RUN rm -rf /var/lib/apt/lists/*
|
||||
#RUN ln -s /lib/x86_64-linux-gnu/libdl-2.24.so /lib/x86_64-linux-gnu/libdl.so
|
||||
RUN DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get install -y --allow-unauthenticated libc6-dev libgdiplus libx11-dev fonts-roboto tzdata libarchive13t64 liblept5
|
||||
RUN wget https://archive.ubuntu.com/ubuntu/pool/main/t/tiff/libtiff5_4.3.0-6_amd64.deb
|
||||
RUN dpkg -i ./libtiff5_4.3.0-6_amd64.deb
|
||||
RUN rm ./libtiff5_4.3.0-6_amd64.deb
|
||||
|
||||
# Regular stuff
|
||||
#COPY packages /root/.nuget/packages/
|
||||
|
||||
@@ -18,7 +18,14 @@ Development Requirements
|
||||
Runtime Requirements
|
||||
--------------------
|
||||
* [.NET 9.0 SDK](https://dotnet.microsoft.com/download) or newer to run from sources
|
||||
* bot needs `dotnet` command to be available (i.e. alias for the Snap package)
|
||||
* Bot needs `dotnet` command to be available (i.e. alias for the Snap package)
|
||||
* OCR on Linux requires the following dependencies to be installed: `libarchive.so.13`, `liblept.so.5`, `libtiff.so.5`
|
||||
* On Ubuntu 24.04 these are provided by the following packages:
|
||||
```sh
|
||||
sudo apt install libarchive13t64 liblept5
|
||||
wget https://archive.ubuntu.com/ubuntu/pool/main/t/tiff/libtiff5_4.3.0-6_amd64.deb
|
||||
sudo dpkg -i ./libtiff5_4.3.0-6_amd64.deb
|
||||
```
|
||||
* Optionally Google API credentials to access Google Drive:
|
||||
* Create new project in the [Google Cloud Resource Manager](https://console.developers.google.com/cloud-resource-manager)
|
||||
* Select the project and enable [Google Drive API](https://console.developers.google.com/apis/library/drive.googleapis.com)
|
||||
|
||||
Reference in New Issue
Block a user