mycroforge/MycroForge.CLI/CodeGen/RequestClassGenerator.cs

186 lines
6.2 KiB
C#

using System.Text.RegularExpressions;
using Humanizer;
using MycroForge.CLI.Commands;
using MycroForge.Core;
namespace MycroForge.CLI.CodeGen;
/// <summary>
/// Generates a Pydantic request model (<c>Create…Request</c> / <c>Update…Request</c>) for an
/// existing entity by scanning the entity's Python source for <c>name: Mapped[...] = ...</c>
/// declarations and for the <c>from … import …</c> statements those annotations depend on.
/// </summary>
public class RequestClassGenerator
{
    // Names from `typing` that may appear in column annotations even when the entity file does not
    // import them explicitly. Shared static state: never hand this list out directly, copy it.
    private static readonly List<string> PythonTypingImports = ["Any", "Dict", "List", "Optional"];

    // Separators used to break a (possibly nested) annotation into the names it references,
    // e.g. "List[Dict[str,Any]]" -> List, Dict, str, Any and "datetime|None" -> datetime, None.
    private static readonly char[] TypeSeparators = ['[', ']', ',', '|', ' '];

    /// <summary>A parsed <c>from Name import Types...</c> statement of the entity source.</summary>
    private record Import(string Name, List<string> Types)
    {
        // A referenced type matches an imported name when it is exactly that name or a dotted member
        // of it (e.g. "models.User" matches an imported "models"). Identifiers: compare ordinally.
        private static bool IsMatch(string type, string importedName) =>
            type == importedName ||
            type.StartsWith(importedName + ".", StringComparison.Ordinal);

        public bool Match(string type) => Types.Any(t => IsMatch(type, t));

        public string FindType(string type) => Types.First(t => IsMatch(type, t));
    }

    /// <summary>A column copied from the entity: its attribute name and the type inside <c>Mapped[...]</c>.</summary>
    private record Field(string Name, string Type);

    /// <summary>The kind of request class to generate.</summary>
    public enum Type
    {
        Create,
        Update
    }

    private static readonly string[] Template =
    [
        "from pydantic import BaseModel",
        "%imports%",
        "",
        "class %request_type%%entity_class_name%Request(BaseModel):",
        "%fields%",
    ];

    // Group 1: module, group 2: imported names (rest of the line).
    private static readonly Regex ImportInfoRegex = new(@"from\s+(.+)\s+import\s+(.+)");

    // Group 1: attribute name, group 2: annotation inside Mapped[...]. The lazy quantifier stops at the
    // first "]" that is followed by "=", so nested generics work and a "]" inside the assigned
    // expression (e.g. a default string) is not swallowed into the type.
    private static readonly Regex FieldInfoRegex = new(@"([_a-zA-Z-0-9]+)\s*:\s*Mapped\s*\[\s*(.+?)\s*\]\s*=\s*.+");

    private readonly ProjectContext _context;

    /// <param name="context">Project used to read entity files and create the request file.</param>
    /// <exception cref="ArgumentNullException">When <paramref name="context"/> is null.</exception>
    public RequestClassGenerator(ProjectContext context)
    {
        ArgumentNullException.ThrowIfNull(context);
        _context = context;
    }

    /// <summary>
    /// Reads the entity identified by <paramref name="fqn"/> from the db feature's <c>entities</c>
    /// folder and writes <c>{type}_{snake_name}_request.py</c> into the api feature's <c>requests</c> folder.
    /// </summary>
    /// <param name="fqn">Name of the entity to base the request on.</param>
    /// <param name="type">Whether to generate a create or an update request.</param>
    /// <exception cref="ArgumentOutOfRangeException">When <paramref name="type"/> is not a supported value.</exception>
    public async Task Generate(FullyQualifiedName fqn, Type type)
    {
        var entityFilePath = Path.Join(Features.Db.FeatureName, "entities", $"{fqn.FilePath}.py");
        var entitySource = await _context.ReadFile(entityFilePath);
        var fieldInfo = ReadFields(entitySource);

        // A Python class body may not be empty; emit `pass` when the entity has no copyable columns
        // (for example only a primary key and relationships).
        var fields = fieldInfo.Count == 0
            ? "\tpass"
            : string.Join('\n', fieldInfo.Select(field => ToFieldString(field, type)));

        var requestFilePath = Path.Join(
            Features.Api.FeatureName,
            "requests",
            fqn.Namespace,
            $"{type.ToString().ToLowerInvariant()}_{fqn.SnakeCasedName}_request.py"
        );

        var requestSource = string.Join("\n", Template)
            .Replace("%imports%", GetImportString(entitySource, fieldInfo, type))
            .Replace("%request_type%", type.ToString().Pascalize())
            .Replace("%entity_class_name%", fqn.PascalizedName)
            .Replace("%fields%", fields);

        await _context.CreateFile(requestFilePath, requestSource);
    }

    /// <summary>Renders one tab-indented Pydantic field declaration for <paramref name="field"/>.</summary>
    /// <exception cref="ArgumentOutOfRangeException">When <paramref name="type"/> is not a supported value.</exception>
    private static string ToFieldString(Field field, Type type) => type switch
    {
        // NOTE(review): create fields also default to None, so non-Optional columns become optional
        // in the request as well; confirm this is intended.
        Type.Create => $"\t{field.Name}: {field.Type} = None",
        // Update requests are partial: every field is wrapped in Optional and defaults to None.
        Type.Update => $"\t{field.Name}: Optional[{field.Type}] = None",
        _ => throw new ArgumentOutOfRangeException(nameof(type), type, $"Request type {type} is not supported."),
    };

    /// <summary>
    /// Builds the <c>from module import a, b</c> lines needed by the annotations in <paramref name="fields"/>.
    /// Each line ends with "\n". Modules and names appear once, in first-seen order. Returns "" when no
    /// import is needed.
    /// </summary>
    private static string GetImportString(string entitySource, List<Field> fields, Type type)
    {
        var imports = GetImports(entitySource);

        // Module name -> imported names. Update requests wrap every field in Optional[...],
        // so they always need `from typing import Optional`.
        var importStringBuffer = type == Type.Create
            ? new Dictionary<string, List<string>>()
            : new Dictionary<string, List<string>> { ["typing"] = ["Optional"] };

        foreach (var field in fields)
        {
            // Break nested annotations apart so every referenced name can be resolved on its own,
            // e.g. "List[Dict[str,Any]]" -> "List", "Dict", "str", "Any".
            var dissectedTypes = field.Type.Split(TypeSeparators, StringSplitOptions.RemoveEmptyEntries);

            foreach (var dissectedType in dissectedTypes)
            {
                if (imports.FirstOrDefault(i => i.Match(dissectedType)) is not Import import)
                {
                    continue;
                }

                if (!importStringBuffer.TryGetValue(import.Name, out var names))
                {
                    names = [];
                    importStringBuffer.Add(import.Name, names);
                }

                // Resolve against the dissected name, not the whole annotation: for "Optional[datetime]"
                // the datetime import must be looked up with "datetime".
                var name = import.FindType(dissectedType);
                if (!names.Contains(name))
                {
                    names.Add(name);
                }
            }
        }

        return string.Concat(importStringBuffer.Select(
            pair => $"from {pair.Key} import {string.Join(", ", pair.Value)}\n"));
    }

    /// <summary>
    /// Extracts the <c>name: Mapped[type] = ...</c> declarations from the entity source. Declarations
    /// whose assigned expression contains <c>primary_key</c> or <c>relationship(</c> are skipped.
    /// Spaces are removed from names and types.
    /// </summary>
    private static List<Field> ReadFields(string entitySource)
    {
        var fields = new List<Field>();

        foreach (Match match in FieldInfoRegex.Matches(entitySource))
        {
            // Groups[0] holds the whole declaration, including the assigned expression.
            var declaration = match.Groups[0].Value;

            // Primary keys are not part of a request.
            if (AppearsAfterAssignment(declaration, "primary_key")) continue;

            // Relationship fields have to be added manually.
            if (AppearsAfterAssignment(declaration, "relationship(")) continue;

            fields.Add(new Field(Clean(match.Groups[1].Value), Clean(match.Groups[2].Value)));
        }

        return fields;
    }

    // True when `keyword` occurs after the first '=' of `declaration`, i.e. in the assigned expression.
    private static bool AppearsAfterAssignment(string declaration, string keyword) =>
        declaration.IndexOf('=') < declaration.IndexOf(keyword, StringComparison.Ordinal);

    /// <summary>
    /// Parses the single-line <c>from module import a, b</c> statements in the entity source.
    /// The result always contains a <c>typing</c> entry that includes <see cref="PythonTypingImports"/>.
    /// </summary>
    private static List<Import> GetImports(string entitySource)
    {
        var imports = new List<Import>();

        foreach (Match match in ImportInfoRegex.Matches(entitySource))
        {
            var module = Clean(match.Groups[1].Value);
            var importedNames = match.Groups[2].Value;

            // Drop a trailing comment ("from x import y  # noqa") so it is not glued onto the last name.
            var commentStart = importedNames.IndexOf('#');
            if (commentStart >= 0)
            {
                importedNames = importedNames[..commentStart];
            }

            var types = importedNames
                .Split(',')
                // Clean() strips spaces; also drop the parentheses of a one-line "import (a, b)".
                .Select(name => Clean(name).Trim('(', ')'))
                .Where(name => name.Length > 0)
                .ToList();

            imports.Add(new Import(module, types));
        }

        if (imports.FirstOrDefault(i => i.Name == "typing") is Import typingImport)
        {
            typingImport.Types.AddRange(PythonTypingImports.Except(typingImport.Types).ToList());
        }
        else
        {
            // Copy, so the shared static list can never be mutated through this Import.
            imports.Add(new Import("typing", [.. PythonTypingImports]));
        }

        return imports;
    }

    // Removes all space characters and trims remaining surrounding whitespace.
    private static string Clean(string value) => value.Replace(" ", string.Empty).Trim();
}