mycroforge/MycroForge.CLI/CodeGen/EntityLinker.EntityModel.cs
2024-05-29 14:20:18 +02:00

133 lines
4.7 KiB
C#

using Humanizer;
using MycroForge.Parsing;
namespace MycroForge.CLI.CodeGen;
public partial class EntityLinker
{
    /// <summary>
    /// Model of a single Python entity source file. While the parse tree is visited it
    /// records every <c>from ... import ...</c> statement, every class definition, the
    /// <c>__tablename__</c> assignment and every assignment whose text contains
    /// <c>Mapped[</c>. New imports and columns are buffered and written out by
    /// <see cref="Rewrite"/>.
    /// </summary>
    public class EntityModel : PythonSourceModifier
    {
        // Characters that separate the tokens of a "from X import a, b" statement,
        // including the parentheses and backslashes used by multi-line imports.
        private static readonly char[] ImportSeparators = { ' ', '\t', '\r', '\n', ',', '(', ')', '\\' };

        private readonly string _className;
        private readonly string _path;
        private readonly List<PythonParser.Import_fromContext> _importContexts;
        private readonly List<string> _importsBuffer;
        private readonly List<PythonParser.Class_defContext> _classContexts;
        private PythonParser.AssignmentContext? _tableContext;
        private readonly List<PythonParser.AssignmentContext> _columnContexts;
        private readonly List<string> _columnsBuffer;

        // The last visited import is the anchor that buffered imports are attached to.
        private PythonParser.Import_fromContext LastImport =>
            _importContexts.LastOrDefault()
            ?? throw new InvalidOperationException(
                $"No 'from ... import ...' statement was found in {_path}. " +
                "Make sure Initialize() was called and the file contains at least one such import.");

        // The last visited column is the anchor that buffered columns are attached to.
        private PythonParser.AssignmentContext LastColumn =>
            _columnContexts.LastOrDefault()
            ?? throw new InvalidOperationException(
                $"No column was found for entity {_className} in {_path}. " +
                "Make sure Initialize() was called and the entity declares at least one Mapped[...] column.");

        /// <summary>The entity class name passed to the constructor.</summary>
        public string ClassName => _className;

        /// <summary>The source file path passed to the constructor.</summary>
        public string Path => _path;

        /// <summary>The class name converted to an underscored, lower-case identifier.</summary>
        public string FieldName => _className.Underscore().ToLower();

        /// <summary>
        /// The value assigned to <c>__tablename__</c>, with the assignment syntax and the
        /// surrounding single or double quotes removed.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// No <c>__tablename__</c> assignment has been visited.
        /// </exception>
        public string TableName
        {
            get
            {
                if (_tableContext is null)
                {
                    throw new InvalidOperationException(
                        $"Entity {_className} in {_path} has no __tablename__ assignment, " +
                        "or Initialize() was not called.");
                }

                return GetOriginalText(_tableContext)
                    .Replace("__tablename__", string.Empty)
                    .Replace("=", string.Empty)
                    .Replace("\"", string.Empty)
                    // Python string literals may use single quotes as well.
                    .Replace("'", string.Empty)
                    .Trim();
            }
        }

        /// <summary>Creates a model for the entity <paramref name="className"/>.</summary>
        /// <param name="className">Name of the entity class expected in the source.</param>
        /// <param name="path">Path of the source file; used in error messages.</param>
        /// <param name="source">Python source text handed to the base class.</param>
        public EntityModel(string className, string path, string source) : base(source)
        {
            _className = className;
            _path = path;
            _importContexts = new();
            _importsBuffer = new();
            _classContexts = new();
            _columnContexts = new();
            _columnsBuffer = new();
        }

        /// <summary>
        /// Parses and visits the source, validates that the entity and at least one column
        /// exist, seeds the buffers with the original text of the last import and the last
        /// column, and makes sure <c>relationship</c> and <c>ForeignKey</c> are imported.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// The entity was not found, it has no columns, or the file has no
        /// <c>from ... import ...</c> statement to anchor new imports to.
        /// </exception>
        public void Initialize()
        {
            var tree = Parser.file_input();
            Visit(tree);

            // NOTE(review): this is a substring match over the whole class text, so a class
            // that merely mentions the name also satisfies it - confirm this is intended.
            if (!_classContexts.Any(c => GetOriginalText(c).Contains(_className)))
            {
                throw new InvalidOperationException($"Entity {_className} was not found in {_path}.");
            }

            if (_columnContexts.Count == 0)
            {
                throw new InvalidOperationException($"Entity {_className} has no columns.");
            }

            // The anchors' own text goes first so that it is kept when the anchors are rewritten.
            _importsBuffer.Add(GetOriginalText(LastImport));
            _columnsBuffer.Add(GetOriginalText(LastColumn));

            InsertRelationshipImport();
            InsertForeignKeyImport();
        }

        // Buffers "from sqlalchemy.orm import relationship" unless it is already imported.
        private void InsertRelationshipImport() => Import("sqlalchemy.orm", "relationship");

        // Buffers "from sqlalchemy import ForeignKey" unless it is already imported.
        private void InsertForeignKeyImport() => Import("sqlalchemy", "ForeignKey");

        /// <summary>Records every <c>from ... import ...</c> statement that is visited.</summary>
        public override object? VisitImport_from(PythonParser.Import_fromContext context)
        {
            _importContexts.Add(context);
            return base.VisitImport_from(context);
        }

        /// <summary>Records every class definition that is visited.</summary>
        public override object? VisitClass_def(PythonParser.Class_defContext context)
        {
            _classContexts.Add(context);
            return base.VisitClass_def(context);
        }

        /// <summary>
        /// Remembers the <c>__tablename__</c> assignment and collects every assignment whose
        /// text contains <c>Mapped[</c> as a column.
        /// </summary>
        public override object? VisitAssignment(PythonParser.AssignmentContext context)
        {
            var text = GetOriginalText(context);

            if (text.StartsWith("__tablename__", StringComparison.Ordinal))
            {
                _tableContext = context;
            }

            if (text.Contains("Mapped["))
            {
                _columnContexts.Add(context);
            }

            return base.VisitAssignment(context);
        }

        /// <summary>Buffers one column definition.</summary>
        public void AppendColumn(string text) => _columnsBuffer.Add(text);

        /// <summary>Buffers several column definitions as one newline-joined entry.</summary>
        public void AppendColumns(params string[] text) => _columnsBuffer.Add(string.Join('\n', text));

        /// <summary>
        /// Buffers <c>from {from} import {import}</c> unless an equivalent import is already
        /// present in the source or in the buffer.
        /// </summary>
        /// <param name="from">Module to import from, e.g. <c>sqlalchemy.orm</c>.</param>
        /// <param name="import">Name, or comma-separated names, to import.</param>
        public void Import(string from, string import)
        {
            var exists = _importContexts.Select(GetOriginalText).Any(txt => ImportsName(txt, from, import));
            var buffered = _importsBuffer.Any(txt => ImportsName(txt, from, import));

            if (!exists && !buffered)
            {
                _importsBuffer.Add($"from {from} import {import}");
            }
        }

        /// <summary>
        /// Tells whether <paramref name="statement"/> is a <c>from {module} import ...</c>
        /// statement that imports every name listed in <paramref name="names"/>.
        /// </summary>
        /// <remarks>
        /// Module and names are compared as whole tokens. A substring comparison would treat
        /// <c>from db.entities.user_role import UserRole</c> as already importing <c>User</c>
        /// from <c>db.entities.user</c> and leave the generated code without that import.
        /// When in doubt this returns false; the worst outcome is a redundant import line.
        /// Assumes the statement text keeps the whitespace of the original source - TODO confirm.
        /// </remarks>
        private static bool ImportsName(string statement, string module, string names)
        {
            var tokens = statement.Split(ImportSeparators, StringSplitOptions.RemoveEmptyEntries);

            // Expected shape: from <module> import <name> [as <alias>] [<name> ...]
            if (tokens.Length < 4 || tokens[0] != "from" || tokens[1] != module || tokens[2] != "import")
            {
                return false;
            }

            var imported = new HashSet<string>(tokens.Skip(3), StringComparer.Ordinal);
            var requested = names.Split(ImportSeparators, StringSplitOptions.RemoveEmptyEntries);

            return requested.Length > 0 && requested.All(imported.Contains);
        }

        /// <summary>
        /// Passes the buffered imports and columns to the base rewrite for the last import
        /// and the last column, then returns the rewriter's text.
        /// </summary>
        /// <returns>The text produced by the rewriter.</returns>
        /// <exception cref="InvalidOperationException">
        /// No import or no column has been collected, e.g. because
        /// <see cref="Initialize"/> was not called.
        /// </exception>
        public override string Rewrite()
        {
            // Presumably the base overload replaces the node with the given lines; the node's
            // original text is the first buffered line (see Initialize), so it is preserved.
            Rewrite(LastImport, _importsBuffer.ToArray());
            Rewrite(LastColumn, _columnsBuffer.ToArray());
            return Rewriter.GetText();
        }
    }
}