mycroforge/MycroForge.CLI/CodeGen/RequestClassGenerator.cs
2024-05-29 14:20:18 +02:00

161 lines
5.5 KiB
C#

using System.Text.RegularExpressions;
using Humanizer;
namespace MycroForge.CLI.CodeGen;
/// <summary>
/// Generates a Pydantic request class (Create or Update) for an entity by scanning the entity's
/// Python source for single-line <c>name: Mapped[type] = ...</c> field declarations and for the
/// <c>from ... import ...</c> statements those field types depend on.
/// </summary>
public class RequestClassGenerator
{
    /// <summary>
    /// A Python <c>from {Name} import {Types}</c> statement parsed from the entity source.
    /// </summary>
    /// <param name="Name">The module imported from, e.g. <c>typing</c>.</param>
    /// <param name="Types">The names imported from that module.</param>
    public record Import(string Name, List<string> Types)
    {
        /// <summary>
        /// Whether <paramref name="type"/> is one of the imported names, either exactly
        /// (<c>datetime</c>) or as the outer name of a generic type (<c>List[str]</c>).
        /// </summary>
        public bool Match(string type) => Types.Any(t => IsMatch(type, t));

        /// <summary>
        /// Returns the first imported name accepted by <see cref="Match"/> for <paramref name="type"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">No imported name matches <paramref name="type"/>.</exception>
        public string FindType(string type) => Types.First(t => IsMatch(type, t));

        // Require an exact name, or the name immediately followed by '[' (a generic), so that
        // e.g. "date" does not match "datetime" and an empty name never matches anything.
        private static bool IsMatch(string type, string name) =>
            name.Length > 0 &&
            (type == name || type.StartsWith(name + "[", StringComparison.Ordinal));
    }

    /// <summary>A field read from the entity: its Python name and the type inside <c>Mapped[...]</c>.</summary>
    public record Field(string Name, string Type);

    /// <summary>The kind of request class to generate.</summary>
    public enum Type
    {
        /// <summary>Fields keep the entity's type and default to <c>None</c>.</summary>
        Create,

        /// <summary>Fields are wrapped in <c>Optional[...]</c> and default to <c>None</c>.</summary>
        Update
    }

    // Python template for the generated file; the %...% placeholders are substituted in Generate.
    private static readonly string[] Template =
    [
        "from pydantic import BaseModel",
        "%imports%",
        "",
        "class %request_type%%entity_class_name%Request(BaseModel):",
        "%fields%",
    ];

    // Group 1: module name, group 2: the comma-separated imported names (rest of the line).
    private static readonly Regex ImportInfoRegex = new(@"from\s+(.+)\s+import\s+(.+)");

    // Group 1: field name, group 2: the type inside Mapped[...]; requires an "= ..." right-hand side.
    private static readonly Regex FieldInfoRegex = new(@"([_a-zA-Z-0-9]+)\s*:\s*Mapped\s*\[\s*(.+)\s*\]\s*=\s*.+");

    // Python identifiers, used to find every name referenced by a (possibly nested generic) type.
    private static readonly Regex IdentifierRegex = new(@"[A-Za-z_][A-Za-z0-9_]*");

    private readonly ProjectContext _context;

    public RequestClassGenerator(ProjectContext context)
    {
        _context = context;
    }

    /// <summary>
    /// Reads <c>{Db feature}/entities/{path}/{entity}.py</c> and writes
    /// <c>{Api feature}/requests/{path}/{type}_{entity}_request.py</c>, containing a Pydantic request
    /// class with one field per non-relationship <c>Mapped[...]</c> declaration of the entity.
    /// </summary>
    /// <param name="path">Sub folder (below the entities and requests folders) the entity lives in.</param>
    /// <param name="entity">Entity name; snake_cased for file names and PascalCased for the class name.</param>
    /// <param name="type">Which kind of request class to generate.</param>
    /// <exception cref="NotSupportedException"><paramref name="type"/> is not a defined <see cref="Type"/> value.</exception>
    public async Task Generate(string path, string entity, Type type)
    {
        if (!Enum.IsDefined(type))
        {
            throw new NotSupportedException($"Request type {type} is not supported.");
        }

        var entitySnakeCaseName = entity.Underscore().ToLowerInvariant();
        var entityClassName = entity.Pascalize();

        var entitiesFolderPath = $"{Features.Db.FeatureName}/entities/{path}";
        var entityFilePath = $"{entitiesFolderPath}/{entitySnakeCaseName}.py";
        var entitySource = await _context.ReadFile(entityFilePath);

        var fieldInfo = ReadFields(entitySource);

        // A Python class body cannot be empty, so fall back to `pass` when no fields were found.
        var fields = fieldInfo.Count == 0
            ? "\tpass"
            : string.Join('\n', fieldInfo.Select(x => ToFieldString(x, type)));

        var requestsFolderPath = $"{Features.Api.FeatureName}/requests/{path}";
        var requestFilePath =
            $"{requestsFolderPath}/{type.ToString().ToLowerInvariant()}_{entitySnakeCaseName}_request.py";

        var requestSource = string.Join("\n", Template)
            .Replace("%imports%", GetImportString(entitySource, fieldInfo, type))
            .Replace("%request_type%", type.ToString().Pascalize())
            .Replace("%entity_class_name%", entityClassName)
            .Replace("%fields%", fields);

        await _context.CreateFile(requestFilePath, requestSource);
    }

    /// <summary>Renders one tab-indented Pydantic field line for <paramref name="field"/>.</summary>
    /// <exception cref="NotSupportedException"><paramref name="type"/> is not a defined <see cref="Type"/> value.</exception>
    private static string ToFieldString(Field field, Type type) => type switch
    {
        Type.Create => $"\t{field.Name}: {field.Type} = None",
        Type.Update => $"\t{field.Name}: Optional[{field.Type}] = None",
        _ => throw new NotSupportedException($"Request type {type} is not supported.")
    };

    /// <summary>
    /// Builds the import statements the generated request needs: every name referenced by a field
    /// type that the entity imports (or that is a common <c>typing</c> helper), plus <c>Optional</c>
    /// for update requests.
    /// </summary>
    /// <returns>
    /// One <c>from ... import ...</c> line per module, each terminated by a newline, or an empty
    /// string when nothing needs importing.
    /// </returns>
    private static string GetImportString(string entitySource, List<Field> fields, Type type)
    {
        var imports = GetImports(entitySource);
        var importStringBuffer = type == Type.Create
            ? new Dictionary<string, List<string>>()
            : new Dictionary<string, List<string>> { ["typing"] = ["Optional"] };

        foreach (var field in fields)
        {
            // Look at every identifier in the type so nested generics such as Dict[str, Any] or
            // List[datetime] import all the names they reference, not just the outer one.
            foreach (Match identifier in IdentifierRegex.Matches(field.Type))
            {
                var name = identifier.Value;
                if (imports.FirstOrDefault(i => i.Match(name)) is not Import import)
                {
                    continue;
                }

                if (!importStringBuffer.TryGetValue(import.Name, out var names))
                {
                    names = [];
                    importStringBuffer.Add(import.Name, names);
                }

                // Avoid duplicates such as "from typing import Optional, Optional".
                if (!names.Contains(name))
                {
                    names.Add(name);
                }
            }
        }

        return string.Concat(importStringBuffer.Select(
            pair => $"from {pair.Key} import {string.Join(", ", pair.Value)}\n"));
    }

    /// <summary>
    /// Extracts every single-line <c>name: Mapped[type] = ...</c> declaration from the entity source,
    /// skipping relationship fields.
    /// </summary>
    private static List<Field> ReadFields(string entitySource)
    {
        var fields = new List<Field>();

        foreach (Match match in FieldInfoRegex.Matches(entitySource))
        {
            // Relationship fields (= relationship(...)) are not generated; these need to be done manually.
            if (match.Value.Contains("relationship(", StringComparison.Ordinal))
            {
                continue;
            }

            fields.Add(new Field(Clean(match.Groups[1].Value), Clean(match.Groups[2].Value)));
        }

        return fields;
    }

    /// <summary>
    /// Parses the single-line <c>from ... import ...</c> statements of the entity source and makes sure
    /// a <c>typing</c> import exposing <c>Any</c>, <c>Dict</c>, <c>List</c> and <c>Optional</c> exists.
    /// </summary>
    private static List<Import> GetImports(string entitySource)
    {
        var imports = new List<Import>();

        foreach (Match match in ImportInfoRegex.Matches(entitySource))
        {
            var name = Clean(match.Groups[1].Value);

            // Drop a trailing "# comment" and surrounding parentheses so they don't leak into the names.
            var importedNames = match.Groups[2].Value;
            var commentStart = importedNames.IndexOf('#');
            if (commentStart >= 0)
            {
                importedNames = importedNames[..commentStart];
            }

            var types = Clean(importedNames.Trim().Trim('(', ')'))
                .Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)
                .ToList();

            imports.Add(new Import(name, types));
        }

        // Make the common typing helpers resolvable even when the entity does not import them itself.
        string[] typingNames = ["Any", "Dict", "List", "Optional"];
        if (imports.FirstOrDefault(i => i.Name == "typing") is Import typingImport)
        {
            typingImport.Types.AddRange(typingNames);
        }
        else
        {
            imports.Add(new Import("typing", [..typingNames]));
        }

        return imports;
    }

    // Removes all spaces (not only leading/trailing ones) from a captured fragment.
    private static string Clean(string value) => value.Replace(" ", string.Empty).Trim();
}