diff --git a/MycroForge.CLI/CodeGen/EntityLinker.EntityModel.cs b/MycroForge.CLI/CodeGen/EntityLinker.EntityModel.cs index 466866f..6ff1a07 100644 --- a/MycroForge.CLI/CodeGen/EntityLinker.EntityModel.cs +++ b/MycroForge.CLI/CodeGen/EntityLinker.EntityModel.cs @@ -111,15 +111,16 @@ public partial class EntityLinker return base.VisitAssignment(context); } - public void AppendColumn(params string[] text) => _columnsBuffer.Add(string.Join('\n', text)); + public void AppendColumn(string text) => _columnsBuffer.Add(text); + public void AppendColumns(params string[] text) => _columnsBuffer.Add(string.Join('\n', text)); - public void AppendImportIfNotExists(string module, string import, params string[] text) + public void Import(string from, string import) { - var isExisting = _importCtxs.Select(GetOriginalText).Any(ctx => ctx.Contains(module) && ctx.Contains(import)); - var isBuffered = _importsBuffer.Any(txt => txt.Contains(module) && txt.Contains(import)); + var isExisting = _importCtxs.Select(GetOriginalText).Any(ctx => ctx.Contains(from) && ctx.Contains(import)); + var isBuffered = _importsBuffer.Any(txt => txt.Contains(from) && txt.Contains(import)); if (!isExisting && !isBuffered) - _importsBuffer.Add(string.Join('\n', text)); + _importsBuffer.Add($"from {from} import {import}"); } public override string Rewrite() diff --git a/MycroForge.CLI/CodeGen/EntityLinker.cs b/MycroForge.CLI/CodeGen/EntityLinker.cs index ad833e8..6881663 100644 --- a/MycroForge.CLI/CodeGen/EntityLinker.cs +++ b/MycroForge.CLI/CodeGen/EntityLinker.cs @@ -20,11 +20,11 @@ public partial class EntityLinker var left = await LoadEntity(_left); var right = await LoadEntity(_right); - left.AppendColumn([ + left.AppendColumn( $"\t{right.FieldName}: Mapped[\"{right.ClassName}\"] = relationship(back_populates=\"{left.FieldName}\")" - ]); + ); - right.AppendColumn([ + right.AppendColumns([ $"\t{left.FieldName}_id: Mapped[int] = mapped_column(ForeignKey(\"{left.TableName}.id\"))", $"\t{left.FieldName}: Mapped[\"{left.ClassName}\"] = relationship(back_populates=\"{right.FieldName}\")" ]); @@ -38,16 +38,10 @@ public partial class EntityLinker var left = await LoadEntity(_left); var right = await LoadEntity(_right); - left.AppendImportIfNotExists( - "typing", - "List", - "from typing import List" - ); - left.AppendColumn([ - $"\t{right.FieldName.Pluralize()}: Mapped[List[\"{right.ClassName}\"]] = relationship()" - ]); + left.Import(from: "typing", import: "List"); + left.AppendColumn($"\t{right.FieldName.Pluralize()}: Mapped[List[\"{right.ClassName}\"]] = relationship()"); - right.AppendColumn([ + right.AppendColumns([ $"\t{left.FieldName}_id: Mapped[int] = mapped_column(ForeignKey(\"{left.TableName}.id\"))", $"\t{left.FieldName}: Mapped[\"{left.ClassName}\"] = relationship(back_populates=\"{right.FieldName}\")" ]); @@ -61,22 +55,15 @@ public partial class EntityLinker var left = await LoadEntity(_left); var right = await LoadEntity(_right); - left.AppendImportIfNotExists( - "typing", - "Optional", - "from typing import Optional" - ); - left.AppendColumn([ + left.Import(from: "typing", import: "Optional"); + left.AppendColumns([ $"\t{right.FieldName}_id: Mapped[Optional[int]] = mapped_column(ForeignKey(\"{right.TableName}.id\"))", $"\t{right.FieldName}: Mapped[\"{right.ClassName}\"] = relationship(back_populates=\"{left.FieldName.Pluralize()}\")" ]); - left.AppendImportIfNotExists( - "typing", - "List", - "from typing import List" - ); - right.AppendColumn([ + left.Import(from: "typing", import: "List"); + + right.AppendColumns([ $"\t{left.FieldName}_id: Mapped[int] = mapped_column(ForeignKey(\"{left.TableName}.id\"))", $"\t{left.FieldName.Pluralize()}: Mapped[List[\"{left.ClassName}\"]] = relationship(back_populates=\"{right.FieldName}\")" ]); diff --git a/MycroForge.CLI/CodeGen/OrmEnvInitializer.cs b/MycroForge.CLI/CodeGen/OrmEnvInitializer.cs index 3063a3e..39e2671 100644 --- a/MycroForge.CLI/CodeGen/OrmEnvInitializer.cs +++ b/MycroForge.CLI/CodeGen/OrmEnvInitializer.cs @@ -4,21 +4,58 @@ namespace MycroForge.CLI.CodeGen; public class OrmEnvInitializer : PythonSourceModifier { + private PythonParser.Import_fromContext? _alembicImport; + private PythonParser.AssignmentContext? _targetMetaDataAssignment; + private PythonParser.AssignmentContext? _urlAssignmentContext; + private PythonParser.AssignmentContext? _connectableAssignmentContext; + public OrmEnvInitializer(string source) : base(source) { } + public override string Rewrite() + { + var tree = Parser.file_input(); + Visit(tree); + + if (_alembicImport is null) + throw new Exception("Could not find import insertion point."); + + if (_targetMetaDataAssignment is null) + throw new Exception("Could not find metadata insertion point."); + + if (_urlAssignmentContext is null) + throw new Exception("Could not find url insertion point."); + + if (_connectableAssignmentContext is null) + throw new Exception("Could not find connectable insertion point."); + + Rewrite(_alembicImport, [ + GetOriginalText(_alembicImport), + "from orm.settings import OrmSettings", + "from orm.entities.entity_base import EntityBase" + ]); + + Rewrite(_targetMetaDataAssignment, "target_metadata = EntityBase.metadata"); + + Rewrite(_urlAssignmentContext, "url = OrmSettings.get_connectionstring()"); + + const string indent = " "; + Rewrite(_connectableAssignmentContext, [ + "url = OrmSettings.get_connectionstring()", + $"{indent}context.config.set_main_option('sqlalchemy.url', url)", + $"{indent}{GetOriginalText(_connectableAssignmentContext)}" + ]); + + return Rewriter.GetText(); + } + public override object? VisitImport_from(PythonParser.Import_fromContext context) { var text = GetOriginalText(context); - if (text != "from alembic import context") return null; - - Rewrite(context, - text, - "from orm.settings import OrmSettings", - "from orm.entities.entity_base import EntityBase" - ); + if (text == "from alembic import context") + _alembicImport = context; return base.VisitImport_from(context); } @@ -28,23 +65,13 @@ public class OrmEnvInitializer : PythonSourceModifier var text = GetOriginalText(context); if (text == "target_metadata = None") - { - Rewrite(context, "target_metadata = EntityBase.metadata"); - } + _targetMetaDataAssignment = context; + else if (text == "url = config.get_main_option(\"sqlalchemy.url\")") - { - Rewrite(context, "url = OrmSettings.get_connectionstring()"); - } + _urlAssignmentContext = context; + else if (text.StartsWith("connectable =")) - { - // Important note, the indent here is 4 spaces and not tab(s). - const string indent = " "; - Rewrite(context, [ - "url = OrmSettings.get_connectionstring()", - $"{indent}context.config.set_main_option('sqlalchemy.url', url)", - $"{indent}{text}" - ]); - } + _connectableAssignmentContext = context; return base.VisitAssignment(context); } diff --git a/MycroForge.CLI/CodeGen/OrmEnvUpdater.cs b/MycroForge.CLI/CodeGen/OrmEnvUpdater.cs index 68877b2..f2b40f4 100644 --- a/MycroForge.CLI/CodeGen/OrmEnvUpdater.cs +++ b/MycroForge.CLI/CodeGen/OrmEnvUpdater.cs @@ -6,6 +6,7 @@ public class OrmEnvUpdater : PythonSourceModifier { private readonly string _moduleName; private readonly string _className; + private PythonParser.Import_fromContext? _lastImport; public OrmEnvUpdater(string source, string moduleName, string className) : base(source) { @@ -13,17 +14,31 @@ public class OrmEnvUpdater : PythonSourceModifier _className = className; } - public override object? VisitImport_from(PythonParser.Import_fromContext context) + public override string Rewrite() { - var text = GetOriginalText(context); + var tree = Parser.file_input(); + Visit(tree); - if (text != "from orm.entities.entity_base import EntityBase") return null; + if (_lastImport is null) + throw new Exception("Could not find import insertion point."); + + var text = GetOriginalText(_lastImport); - Rewrite(context, [ + Rewrite(_lastImport, [ text, $"from orm.entities.{_moduleName} import {_className}" ]); + return Rewriter.GetText(); + } + + public override object? VisitImport_from(PythonParser.Import_fromContext context) + { + var text = GetOriginalText(context); + + if (text.StartsWith("from orm.entities")) + _lastImport = context; + return base.VisitImport_from(context); } } \ No newline at end of file diff --git a/MycroForge.CLI/CodeGen/PythonSourceModifier.cs b/MycroForge.CLI/CodeGen/PythonSourceModifier.cs index 1f84b9b..f61a995 100644 --- a/MycroForge.CLI/CodeGen/PythonSourceModifier.cs +++ b/MycroForge.CLI/CodeGen/PythonSourceModifier.cs @@ -18,12 +18,7 @@ public abstract class PythonSourceModifier : PythonParserBaseVisitor Rewriter = new TokenStreamRewriter(Stream); } - public virtual string Rewrite() - { - var tree = Parser.file_input(); - Visit(tree); - return Rewriter.GetText(); - } + public abstract string Rewrite(); protected string GetOriginalText(ParserRuleContext context) {