Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions Kattbot.Tests/PetTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NSubstitute;
using SixLabors.ImageSharp;
using SixLabors.ImageSharp.PixelFormats;

namespace Kattbot.Tests;

Expand All @@ -14,16 +15,22 @@ namespace Kattbot.Tests;
public class PetTests
{
[TestMethod]
public async Task PetPetTest()
[DataRow("SamplePNGImage_100kbmb.png")]
[DataRow("SamplePNGImage_500kbmb.png")]
[DataRow("SamplePNGImage_1mbmb.png")]
[DataRow("SamplePNGImage_3mbmb.png")]
[DataRow("SamplePNGImage_10mbmb.png")]
[DataRow("SamplePNGImage_30mbmb.png")]
public async Task PetPetTest(string inputImage)
{
var puppeteerFactory = new PuppeteerFactory();

var logger = Substitute.For<ILogger<PetPetClient>>();

var makeEmojiClient = new PetPetClient(puppeteerFactory, logger);

string inputFile = Path.Combine(Path.GetTempPath(), "froge.png");
string ouputFile = Path.Combine(Path.GetTempPath(), "pet_froge.gif");
string inputFile = Path.Combine(Path.GetTempPath(), "test_images", inputImage);
string ouputFile = Path.Combine(Path.GetTempPath(), "pet-test-output", $"pet_{inputImage.Split(".")[0]}.gif");

byte[] resultBytes = await makeEmojiClient.PetPet(inputFile);

Expand All @@ -32,17 +39,36 @@ public async Task PetPetTest()
await image.SaveAsGifAsync(ouputFile);
}

[TestMethod]
public async Task CropToCircle()
[DataTestMethod]
[DataRow("froge.png")]
[DataRow("test_working.png")]
[DataRow("test_not_working.png")]
public async Task CropToCircle(string inputFilename)
{
string inputFile = Path.Combine(Path.GetTempPath(), inputFilename);
string ouputFile = Path.Combine(Path.GetTempPath(), "kattbot", $"cropped_{inputFilename}");

var imageService = new ImageService(null!);

using var image = Image.Load<Rgba32>(inputFile);

var croppedImage = imageService.CropToCircle(image);

await croppedImage.SaveAsPngAsync(ouputFile);
}

[DataTestMethod]
[DataRow("froge.png")]
public async Task Twirl(string inputFilename)
{
string inputFile = Path.Combine(Path.GetTempPath(), "froge.png");
string ouputFile = Path.Combine(Path.GetTempPath(), "froge_circle.png");
string inputFile = Path.Combine(Path.GetTempPath(), inputFilename);
string ouputFile = Path.Combine(Path.GetTempPath(), "kattbot", $"twirled_{inputFilename}");

var imageService = new ImageService(null!);

using var image = Image.Load(inputFile);
using var image = Image.Load<Rgba32>(inputFile);

var croppedImage = imageService.CropImageToCircle(image);
var croppedImage = imageService.TwirlImage(image);

await croppedImage.SaveAsPngAsync(ouputFile);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
using System;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using DSharpPlus.CommandsNext;
using DSharpPlus.Entities;
using Kattbot.Helpers;
using Kattbot.Services.Dalle;
using Kattbot.Services.Images;
using MediatR;
Expand All @@ -14,23 +14,23 @@ namespace Kattbot.CommandHandlers.Images;
#pragma warning disable SA1402 // File may only contain a single type
public class DallePromptCommand : CommandRequest
{
public string Prompt { get; set; }

public DallePromptCommand(CommandContext ctx, string prompt)
: base(ctx)
{
Prompt = prompt;
}

public string Prompt { get; set; }
}

public class DallePromptCommandHandler : IRequestHandler<DallePromptCommand>
public class DallePromptHandler : IRequestHandler<DallePromptCommand>
{
private const int MaxEmbedTitleLength = 256;

private readonly DalleHttpClient _dalleHttpClient;
private readonly ImageService _imageService;

public DallePromptCommandHandler(DalleHttpClient dalleHttpClient, ImageService imageService)
public DallePromptHandler(DalleHttpClient dalleHttpClient, ImageService imageService)
{
_dalleHttpClient = dalleHttpClient;
_imageService = imageService;
Expand All @@ -42,7 +42,7 @@ public async Task Handle(DallePromptCommand request, CancellationToken cancellat

try
{
var response = await _dalleHttpClient.CreateImage(new CreateImageRequest { Prompt = request.Prompt });
var response = await _dalleHttpClient.CreateImage(new CreateImageRequest { Prompt = request.Prompt, User = request.Ctx.User.Id.ToString() });

if (response.Data == null || !response.Data.Any()) throw new Exception("Empty result");

Expand All @@ -52,11 +52,7 @@ public async Task Handle(DallePromptCommand request, CancellationToken cancellat

var imageStream = await _imageService.GetImageStream(image);

string safeFileName = new(Encoding.ASCII.GetString(Encoding.ASCII.GetBytes(request.Prompt))
.Select(c => char.IsLetterOrDigit(c) ? c : '_')
.ToArray());

string fileName = $"{safeFileName}.{imageStream.FileExtension}";
var fileName = request.Prompt.ToSafeFilename(imageStream.FileExtension);

var truncatedPrompt = request.Prompt.Length > MaxEmbedTitleLength
? $"{request.Prompt[..(MaxEmbedTitleLength - 3)]}..."
Expand Down
215 changes: 215 additions & 0 deletions Kattbot/CommandHandlers/Images/DallifyImage.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using DSharpPlus.CommandsNext;
using DSharpPlus.Entities;
using Kattbot.Helpers;
using Kattbot.Services.Dalle;
using Kattbot.Services.Images;
using MediatR;

namespace Kattbot.CommandHandlers.Images;

#pragma warning disable SA1402 // File may only contain a single type
public class DallifyEmoteRequest : CommandRequest
{
public DallifyEmoteRequest(CommandContext ctx, DiscordEmoji emoji)
: base(ctx)
{
Emoji = emoji;
}

public DiscordEmoji Emoji { get; set; }
}

public class DallifyUserRequest : CommandRequest
{
public DallifyUserRequest(CommandContext ctx, DiscordUser user)
: base(ctx)
{
User = user;
}

public DiscordUser User { get; set; }
}

public class DallifyImageRequest : CommandRequest
{
public DallifyImageRequest(CommandContext ctx)
: base(ctx)
{ }
}

public class DallifyImageHandler : IRequestHandler<DallifyEmoteRequest>,
IRequestHandler<DallifyUserRequest>,
IRequestHandler<DallifyImageRequest>
{
public static readonly int Size256 = 256;
public static readonly int Size512 = 512;
public static readonly int Size1024 = 1024;

private const int MaxImageSizeInMb = 4;

private static readonly int[] ValidSizes = { Size256, Size512, Size1024 };

private readonly DalleHttpClient _dalleHttpClient;
private readonly ImageService _imageService;
private readonly DiscordResolver _discordResolver;

public DallifyImageHandler(DalleHttpClient dalleHttpClient, ImageService imageService, DiscordResolver discordResolver)
{
_dalleHttpClient = dalleHttpClient;
_imageService = imageService;
_discordResolver = discordResolver;
}

public async Task Handle(DallifyEmoteRequest request, CancellationToken cancellationToken)
{
var ctx = request.Ctx;
var userId = ctx.User.Id;

var emoji = request.Emoji;

var message = await request.Ctx.RespondAsync("Working on it");

try
{
var imageUrl = emoji.GetEmojiImageUrl();

var imageStreamResult = await DallifyImage(imageUrl, userId, Size256);

using var imageStream = imageStreamResult.MemoryStream;
var fileExtension = imageStreamResult.FileExtension;

var imageName = emoji.Id != 0 ? emoji.Id.ToString() : emoji.Name;

string fileName = $"{imageName}.{fileExtension}";

var mb = new DiscordMessageBuilder()
.AddFile(fileName, imageStream)
.WithContent($"There you go {request.Ctx.Member?.Mention ?? "Unknown user"}");

await message.DeleteAsync();

await request.Ctx.RespondAsync(mb);
}
catch (Exception)
{
await message.DeleteAsync();
throw;
}
}

public async Task Handle(DallifyUserRequest request, CancellationToken cancellationToken)
{
var ctx = request.Ctx;
var user = request.User;
var guild = ctx.Guild;

var message = await request.Ctx.RespondAsync("Working on it");

try
{
var userAsMember = await _discordResolver.ResolveGuildMember(guild, user.Id) ?? throw new Exception("Invalid user");

var imageUrl = userAsMember.GuildAvatarUrl
?? userAsMember.AvatarUrl
?? throw new Exception("Couldn't load user avatar");

var imageStreamResult = await DallifyImage(imageUrl, user.Id, Size512);

using var imageStream = imageStreamResult.MemoryStream;
var fileExtension = imageStreamResult.FileExtension;

var imageFilename = user.GetNicknameOrUsername().ToSafeFilename(fileExtension);

DiscordMessageBuilder mb = new DiscordMessageBuilder()
.AddFile(imageFilename, imageStream)
.WithContent($"There you go {request.Ctx.Member?.Mention ?? "Unknown user"}");

await message.DeleteAsync();

await request.Ctx.RespondAsync(mb);
}
catch (Exception)
{
await message.DeleteAsync();
throw;
}
}

public async Task Handle(DallifyImageRequest request, CancellationToken cancellationToken)
{
var ctx = request.Ctx;
var user = ctx.User;
var message = ctx.Message;

var imageUrl = await message.GetImageUrlFromMessage();

if (imageUrl == null)
{
await ctx.RespondAsync("I didn't find any images.");
return;
}

var wokingOnItMessage = await request.Ctx.RespondAsync("Working on it");

try
{
var imageStreamResult = await DallifyImage(imageUrl, user.Id, Size1024);

using var imageStream = imageStreamResult.MemoryStream;
var fileExtension = imageStreamResult.FileExtension;

var imageFilename = $"{Guid.NewGuid()}.{fileExtension}";

DiscordMessageBuilder mb = new DiscordMessageBuilder()
.AddFile(imageFilename, imageStream)
.WithContent($"There you go {request.Ctx.Member?.Mention ?? "Unknown user"}");

await wokingOnItMessage.DeleteAsync();

await request.Ctx.RespondAsync(mb);
}
catch (Exception)
{
await wokingOnItMessage.DeleteAsync();
throw;
}
}

private async Task<ImageStreamResult> DallifyImage(string imageUrl, ulong userId, int maxSize)
{
var image = await _imageService.DownloadImage(imageUrl);

var imageAsPng = await _imageService.ConvertImageToPng(image, MaxImageSizeInMb);

var squaredImage = _imageService.CropToSquare(imageAsPng);

var resultSize = Math.Min(maxSize, Math.Max(ValidSizes.Reverse().FirstOrDefault(s => squaredImage.Height > s), ValidSizes[0]));

var fileName = $"{Guid.NewGuid()}.png";

var inputImageStream = await _imageService.GetImageStream(squaredImage);

var imageVariationRequest = new CreateImageVariationRequest
{
Image = inputImageStream.MemoryStream.ToArray(),
Size = $"{resultSize}x{resultSize}",
User = userId.ToString(),
};

var response = await _dalleHttpClient.CreateImageVariation(imageVariationRequest, fileName);

if (response.Data == null || !response.Data.Any()) throw new Exception("Empty result");

var imageResponseUrl = response.Data.First();

var imageResult = await _imageService.DownloadImage(imageResponseUrl.Url);

var imageStream = await _imageService.GetImageStream(imageResult);

return imageStream;
}
}
Loading