Refactored source modifiers

This commit is contained in:
Donné Napo 2024-04-25 17:15:21 +02:00
parent 8be76f224f
commit f9d04ef13e
5 changed files with 86 additions and 61 deletions

View File

@ -111,15 +111,16 @@ public partial class EntityLinker
return base.VisitAssignment(context); return base.VisitAssignment(context);
} }
public void AppendColumn(params string[] text) => _columnsBuffer.Add(string.Join('\n', text)); public void AppendColumn(string text) => _columnsBuffer.Add(text);
public void AppendColumns(params string[] text) => _columnsBuffer.Add(string.Join('\n', text));
public void AppendImportIfNotExists(string module, string import, params string[] text) public void Import(string from, string import)
{ {
var isExisting = _importCtxs.Select(GetOriginalText).Any(ctx => ctx.Contains(module) && ctx.Contains(import)); var isExisting = _importCtxs.Select(GetOriginalText).Any(ctx => ctx.Contains(from) && ctx.Contains(import));
var isBuffered = _importsBuffer.Any(txt => txt.Contains(module) && txt.Contains(import)); var isBuffered = _importsBuffer.Any(txt => txt.Contains(from) && txt.Contains(import));
if (!isExisting && !isBuffered) if (!isExisting && !isBuffered)
_importsBuffer.Add(string.Join('\n', text)); _importsBuffer.Add($"from {from} import {import}");
} }
public override string Rewrite() public override string Rewrite()

View File

@ -20,11 +20,11 @@ public partial class EntityLinker
var left = await LoadEntity(_left); var left = await LoadEntity(_left);
var right = await LoadEntity(_right); var right = await LoadEntity(_right);
left.AppendColumn([ left.AppendColumn(
$"\t{right.FieldName}: Mapped[\"{right.ClassName}\"] = relationship(back_populates=\"{left.FieldName}\")" $"\t{right.FieldName}: Mapped[\"{right.ClassName}\"] = relationship(back_populates=\"{left.FieldName}\")"
]); );
right.AppendColumn([ right.AppendColumns([
$"\t{left.FieldName}_id: Mapped[int] = mapped_column(ForeignKey(\"{left.TableName}.id\"))", $"\t{left.FieldName}_id: Mapped[int] = mapped_column(ForeignKey(\"{left.TableName}.id\"))",
$"\t{left.FieldName}: Mapped[\"{left.ClassName}\"] = relationship(back_populates=\"{right.FieldName}\")" $"\t{left.FieldName}: Mapped[\"{left.ClassName}\"] = relationship(back_populates=\"{right.FieldName}\")"
]); ]);
@ -38,16 +38,10 @@ public partial class EntityLinker
var left = await LoadEntity(_left); var left = await LoadEntity(_left);
var right = await LoadEntity(_right); var right = await LoadEntity(_right);
left.AppendImportIfNotExists( left.Import(from: "typing", import: "List");
"typing", left.AppendColumn($"\t{right.FieldName.Pluralize()}: Mapped[List[\"{right.ClassName}\"]] = relationship()");
"List",
"from typing import List"
);
left.AppendColumn([
$"\t{right.FieldName.Pluralize()}: Mapped[List[\"{right.ClassName}\"]] = relationship()"
]);
right.AppendColumn([ right.AppendColumns([
$"\t{left.FieldName}_id: Mapped[int] = mapped_column(ForeignKey(\"{left.TableName}.id\"))", $"\t{left.FieldName}_id: Mapped[int] = mapped_column(ForeignKey(\"{left.TableName}.id\"))",
$"\t{left.FieldName}: Mapped[\"{left.ClassName}\"] = relationship(back_populates=\"{right.FieldName}\")" $"\t{left.FieldName}: Mapped[\"{left.ClassName}\"] = relationship(back_populates=\"{right.FieldName}\")"
]); ]);
@ -61,22 +55,15 @@ public partial class EntityLinker
var left = await LoadEntity(_left); var left = await LoadEntity(_left);
var right = await LoadEntity(_right); var right = await LoadEntity(_right);
left.AppendImportIfNotExists( left.Import(from: "typing", import: "Optional");
"typing", left.AppendColumns([
"Optional",
"from typing import Optional"
);
left.AppendColumn([
$"\t{right.FieldName}_id: Mapped[Optional[int]] = mapped_column(ForeignKey(\"{right.TableName}.id\"))", $"\t{right.FieldName}_id: Mapped[Optional[int]] = mapped_column(ForeignKey(\"{right.TableName}.id\"))",
$"\t{right.FieldName}: Mapped[\"{right.ClassName}\"] = relationship(back_populates=\"{left.FieldName.Pluralize()}\")" $"\t{right.FieldName}: Mapped[\"{right.ClassName}\"] = relationship(back_populates=\"{left.FieldName.Pluralize()}\")"
]); ]);
left.AppendImportIfNotExists( left.Import(from: "typing", import: "List");
"typing",
"List", right.AppendColumns([
"from typing import List"
);
right.AppendColumn([
$"\t{left.FieldName}_id: Mapped[int] = mapped_column(ForeignKey(\"{left.TableName}.id\"))", $"\t{left.FieldName}_id: Mapped[int] = mapped_column(ForeignKey(\"{left.TableName}.id\"))",
$"\t{left.FieldName.Pluralize()}: Mapped[List[\"{left.ClassName}\"]] = relationship(back_populates=\"{right.FieldName}\")" $"\t{left.FieldName.Pluralize()}: Mapped[List[\"{left.ClassName}\"]] = relationship(back_populates=\"{right.FieldName}\")"
]); ]);

View File

@ -4,21 +4,58 @@ namespace MycroForge.CLI.CodeGen;
public class OrmEnvInitializer : PythonSourceModifier public class OrmEnvInitializer : PythonSourceModifier
{ {
private PythonParser.Import_fromContext? _alembicImport;
private PythonParser.AssignmentContext? _targetMetaDataAssignment;
private PythonParser.AssignmentContext? _urlAssignmentContext;
private PythonParser.AssignmentContext? _connectableAssignmentContext;
public OrmEnvInitializer(string source) : base(source) public OrmEnvInitializer(string source) : base(source)
{ {
} }
public override string Rewrite()
{
var tree = Parser.file_input();
Visit(tree);
if (_alembicImport is null)
throw new Exception("Could not find import insertion point.");
if (_targetMetaDataAssignment is null)
throw new Exception("Could not find metadata insertion point.");
if (_urlAssignmentContext is null)
throw new Exception("Could not find url insertion point.");
if (_connectableAssignmentContext is null)
throw new Exception("Could not find connectable insertion point.");
Rewrite(_alembicImport, [
GetOriginalText(_alembicImport),
"from orm.settings import OrmSettings",
"from orm.entities.entity_base import EntityBase"
]);
Rewrite(_targetMetaDataAssignment, "target_metadata = EntityBase.metadata");
Rewrite(_urlAssignmentContext, "url = OrmSettings.get_connectionstring()");
const string indent = " ";
Rewrite(_connectableAssignmentContext, [
"url = OrmSettings.get_connectionstring()",
$"{indent}context.config.set_main_option('sqlalchemy.url', url)",
$"{indent}{GetOriginalText(_connectableAssignmentContext)}"
]);
return Rewriter.GetText();
}
public override object? VisitImport_from(PythonParser.Import_fromContext context) public override object? VisitImport_from(PythonParser.Import_fromContext context)
{ {
var text = GetOriginalText(context); var text = GetOriginalText(context);
if (text != "from alembic import context") return null; if (text == "from alembic import context")
_alembicImport = context;
Rewrite(context,
text,
"from orm.settings import OrmSettings",
"from orm.entities.entity_base import EntityBase"
);
return base.VisitImport_from(context); return base.VisitImport_from(context);
} }
@ -28,23 +65,13 @@ public class OrmEnvInitializer : PythonSourceModifier
var text = GetOriginalText(context); var text = GetOriginalText(context);
if (text == "target_metadata = None") if (text == "target_metadata = None")
{ _targetMetaDataAssignment = context;
Rewrite(context, "target_metadata = EntityBase.metadata");
}
else if (text == "url = config.get_main_option(\"sqlalchemy.url\")") else if (text == "url = config.get_main_option(\"sqlalchemy.url\")")
{ _urlAssignmentContext = context;
Rewrite(context, "url = OrmSettings.get_connectionstring()");
}
else if (text.StartsWith("connectable =")) else if (text.StartsWith("connectable ="))
{ _connectableAssignmentContext = context;
// Important note, the indent here is 4 spaces and not tab(s).
const string indent = " ";
Rewrite(context, [
"url = OrmSettings.get_connectionstring()",
$"{indent}context.config.set_main_option('sqlalchemy.url', url)",
$"{indent}{text}"
]);
}
return base.VisitAssignment(context); return base.VisitAssignment(context);
} }

View File

@ -6,6 +6,7 @@ public class OrmEnvUpdater : PythonSourceModifier
{ {
private readonly string _moduleName; private readonly string _moduleName;
private readonly string _className; private readonly string _className;
private PythonParser.Import_fromContext? _lastImport;
public OrmEnvUpdater(string source, string moduleName, string className) : base(source) public OrmEnvUpdater(string source, string moduleName, string className) : base(source)
{ {
@ -13,17 +14,31 @@ public class OrmEnvUpdater : PythonSourceModifier
_className = className; _className = className;
} }
public override object? VisitImport_from(PythonParser.Import_fromContext context) public override string Rewrite()
{ {
var text = GetOriginalText(context); var tree = Parser.file_input();
Visit(tree);
if (text != "from orm.entities.entity_base import EntityBase") return null; if (_lastImport is null)
throw new Exception("Could not find import insertion point.");
var text = GetOriginalText(_lastImport);
Rewrite(context, [ Rewrite(_lastImport, [
text, text,
$"from orm.entities.{_moduleName} import {_className}" $"from orm.entities.{_moduleName} import {_className}"
]); ]);
return Rewriter.GetText();
}
public override object? VisitImport_from(PythonParser.Import_fromContext context)
{
var text = GetOriginalText(context);
if (text.StartsWith("from orm.entities"))
_lastImport = context;
return base.VisitImport_from(context); return base.VisitImport_from(context);
} }
} }

View File

@ -18,12 +18,7 @@ public abstract class PythonSourceModifier : PythonParserBaseVisitor<object?>
Rewriter = new TokenStreamRewriter(Stream); Rewriter = new TokenStreamRewriter(Stream);
} }
public virtual string Rewrite() public abstract string Rewrite();
{
var tree = Parser.file_input();
Visit(tree);
return Rewriter.GetText();
}
protected string GetOriginalText(ParserRuleContext context) protected string GetOriginalText(ParserRuleContext context)
{ {