Refactored source modifiers
This commit is contained in:
parent
8be76f224f
commit
f9d04ef13e
@ -111,15 +111,16 @@ public partial class EntityLinker
|
||||
return base.VisitAssignment(context);
|
||||
}
|
||||
|
||||
public void AppendColumn(params string[] text) => _columnsBuffer.Add(string.Join('\n', text));
|
||||
public void AppendColumn(string text) => _columnsBuffer.Add(text);
|
||||
public void AppendColumns(params string[] text) => _columnsBuffer.Add(string.Join('\n', text));
|
||||
|
||||
public void AppendImportIfNotExists(string module, string import, params string[] text)
|
||||
public void Import(string from, string import)
|
||||
{
|
||||
var isExisting = _importCtxs.Select(GetOriginalText).Any(ctx => ctx.Contains(module) && ctx.Contains(import));
|
||||
var isBuffered = _importsBuffer.Any(txt => txt.Contains(module) && txt.Contains(import));
|
||||
var isExisting = _importCtxs.Select(GetOriginalText).Any(ctx => ctx.Contains(from) && ctx.Contains(import));
|
||||
var isBuffered = _importsBuffer.Any(txt => txt.Contains(from) && txt.Contains(import));
|
||||
|
||||
if (!isExisting && !isBuffered)
|
||||
_importsBuffer.Add(string.Join('\n', text));
|
||||
_importsBuffer.Add($"from {from} import {import}");
|
||||
}
|
||||
|
||||
public override string Rewrite()
|
||||
|
@ -20,11 +20,11 @@ public partial class EntityLinker
|
||||
var left = await LoadEntity(_left);
|
||||
var right = await LoadEntity(_right);
|
||||
|
||||
left.AppendColumn([
|
||||
left.AppendColumn(
|
||||
$"\t{right.FieldName}: Mapped[\"{right.ClassName}\"] = relationship(back_populates=\"{left.FieldName}\")"
|
||||
]);
|
||||
);
|
||||
|
||||
right.AppendColumn([
|
||||
right.AppendColumns([
|
||||
$"\t{left.FieldName}_id: Mapped[int] = mapped_column(ForeignKey(\"{left.TableName}.id\"))",
|
||||
$"\t{left.FieldName}: Mapped[\"{left.ClassName}\"] = relationship(back_populates=\"{right.FieldName}\")"
|
||||
]);
|
||||
@ -38,16 +38,10 @@ public partial class EntityLinker
|
||||
var left = await LoadEntity(_left);
|
||||
var right = await LoadEntity(_right);
|
||||
|
||||
left.AppendImportIfNotExists(
|
||||
"typing",
|
||||
"List",
|
||||
"from typing import List"
|
||||
);
|
||||
left.AppendColumn([
|
||||
$"\t{right.FieldName.Pluralize()}: Mapped[List[\"{right.ClassName}\"]] = relationship()"
|
||||
]);
|
||||
left.Import(from: "typing", import: "List");
|
||||
left.AppendColumn($"\t{right.FieldName.Pluralize()}: Mapped[List[\"{right.ClassName}\"]] = relationship()");
|
||||
|
||||
right.AppendColumn([
|
||||
right.AppendColumns([
|
||||
$"\t{left.FieldName}_id: Mapped[int] = mapped_column(ForeignKey(\"{left.TableName}.id\"))",
|
||||
$"\t{left.FieldName}: Mapped[\"{left.ClassName}\"] = relationship(back_populates=\"{right.FieldName}\")"
|
||||
]);
|
||||
@ -61,22 +55,15 @@ public partial class EntityLinker
|
||||
var left = await LoadEntity(_left);
|
||||
var right = await LoadEntity(_right);
|
||||
|
||||
left.AppendImportIfNotExists(
|
||||
"typing",
|
||||
"Optional",
|
||||
"from typing import Optional"
|
||||
);
|
||||
left.AppendColumn([
|
||||
left.Import(from: "typing", import: "Optional");
|
||||
left.AppendColumns([
|
||||
$"\t{right.FieldName}_id: Mapped[Optional[int]] = mapped_column(ForeignKey(\"{right.TableName}.id\"))",
|
||||
$"\t{right.FieldName}: Mapped[\"{right.ClassName}\"] = relationship(back_populates=\"{left.FieldName.Pluralize()}\")"
|
||||
]);
|
||||
|
||||
left.AppendImportIfNotExists(
|
||||
"typing",
|
||||
"List",
|
||||
"from typing import List"
|
||||
);
|
||||
right.AppendColumn([
|
||||
left.Import(from: "typing", import: "List");
|
||||
|
||||
right.AppendColumns([
|
||||
$"\t{left.FieldName}_id: Mapped[int] = mapped_column(ForeignKey(\"{left.TableName}.id\"))",
|
||||
$"\t{left.FieldName.Pluralize()}: Mapped[List[\"{left.ClassName}\"]] = relationship(back_populates=\"{right.FieldName}\")"
|
||||
]);
|
||||
|
@ -4,21 +4,58 @@ namespace MycroForge.CLI.CodeGen;
|
||||
|
||||
public class OrmEnvInitializer : PythonSourceModifier
|
||||
{
|
||||
private PythonParser.Import_fromContext? _alembicImport;
|
||||
private PythonParser.AssignmentContext? _targetMetaDataAssignment;
|
||||
private PythonParser.AssignmentContext? _urlAssignmentContext;
|
||||
private PythonParser.AssignmentContext? _connectableAssignmentContext;
|
||||
|
||||
public OrmEnvInitializer(string source) : base(source)
|
||||
{
|
||||
}
|
||||
|
||||
public override string Rewrite()
|
||||
{
|
||||
var tree = Parser.file_input();
|
||||
Visit(tree);
|
||||
|
||||
if (_alembicImport is null)
|
||||
throw new Exception("Could not find import insertion point.");
|
||||
|
||||
if (_targetMetaDataAssignment is null)
|
||||
throw new Exception("Could not find metadata insertion point.");
|
||||
|
||||
if (_urlAssignmentContext is null)
|
||||
throw new Exception("Could not find url insertion point.");
|
||||
|
||||
if (_connectableAssignmentContext is null)
|
||||
throw new Exception("Could not find connectable insertion point.");
|
||||
|
||||
Rewrite(_alembicImport, [
|
||||
GetOriginalText(_alembicImport),
|
||||
"from orm.settings import OrmSettings",
|
||||
"from orm.entities.entity_base import EntityBase"
|
||||
]);
|
||||
|
||||
Rewrite(_targetMetaDataAssignment, "target_metadata = EntityBase.metadata");
|
||||
|
||||
Rewrite(_urlAssignmentContext, "url = OrmSettings.get_connectionstring()");
|
||||
|
||||
const string indent = " ";
|
||||
Rewrite(_connectableAssignmentContext, [
|
||||
"url = OrmSettings.get_connectionstring()",
|
||||
$"{indent}context.config.set_main_option('sqlalchemy.url', url)",
|
||||
$"{indent}{GetOriginalText(_connectableAssignmentContext)}"
|
||||
]);
|
||||
|
||||
return Rewriter.GetText();
|
||||
}
|
||||
|
||||
public override object? VisitImport_from(PythonParser.Import_fromContext context)
|
||||
{
|
||||
var text = GetOriginalText(context);
|
||||
|
||||
if (text != "from alembic import context") return null;
|
||||
|
||||
Rewrite(context,
|
||||
text,
|
||||
"from orm.settings import OrmSettings",
|
||||
"from orm.entities.entity_base import EntityBase"
|
||||
);
|
||||
if (text == "from alembic import context")
|
||||
_alembicImport = context;
|
||||
|
||||
return base.VisitImport_from(context);
|
||||
}
|
||||
@ -28,23 +65,13 @@ public class OrmEnvInitializer : PythonSourceModifier
|
||||
var text = GetOriginalText(context);
|
||||
|
||||
if (text == "target_metadata = None")
|
||||
{
|
||||
Rewrite(context, "target_metadata = EntityBase.metadata");
|
||||
}
|
||||
_targetMetaDataAssignment = context;
|
||||
|
||||
else if (text == "url = config.get_main_option(\"sqlalchemy.url\")")
|
||||
{
|
||||
Rewrite(context, "url = OrmSettings.get_connectionstring()");
|
||||
}
|
||||
_urlAssignmentContext = context;
|
||||
|
||||
else if (text.StartsWith("connectable ="))
|
||||
{
|
||||
// Important note, the indent here is 4 spaces and not tab(s).
|
||||
const string indent = " ";
|
||||
Rewrite(context, [
|
||||
"url = OrmSettings.get_connectionstring()",
|
||||
$"{indent}context.config.set_main_option('sqlalchemy.url', url)",
|
||||
$"{indent}{text}"
|
||||
]);
|
||||
}
|
||||
_connectableAssignmentContext = context;
|
||||
|
||||
return base.VisitAssignment(context);
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ public class OrmEnvUpdater : PythonSourceModifier
|
||||
{
|
||||
private readonly string _moduleName;
|
||||
private readonly string _className;
|
||||
private PythonParser.Import_fromContext? _lastImport;
|
||||
|
||||
public OrmEnvUpdater(string source, string moduleName, string className) : base(source)
|
||||
{
|
||||
@ -13,17 +14,31 @@ public class OrmEnvUpdater : PythonSourceModifier
|
||||
_className = className;
|
||||
}
|
||||
|
||||
public override object? VisitImport_from(PythonParser.Import_fromContext context)
|
||||
public override string Rewrite()
|
||||
{
|
||||
var text = GetOriginalText(context);
|
||||
var tree = Parser.file_input();
|
||||
Visit(tree);
|
||||
|
||||
if (text != "from orm.entities.entity_base import EntityBase") return null;
|
||||
if (_lastImport is null)
|
||||
throw new Exception("Could not find import insertion point.");
|
||||
|
||||
var text = GetOriginalText(_lastImport);
|
||||
|
||||
Rewrite(context, [
|
||||
Rewrite(_lastImport, [
|
||||
text,
|
||||
$"from orm.entities.{_moduleName} import {_className}"
|
||||
]);
|
||||
|
||||
return Rewriter.GetText();
|
||||
}
|
||||
|
||||
public override object? VisitImport_from(PythonParser.Import_fromContext context)
|
||||
{
|
||||
var text = GetOriginalText(context);
|
||||
|
||||
if (text.StartsWith("from orm.entities"))
|
||||
_lastImport = context;
|
||||
|
||||
return base.VisitImport_from(context);
|
||||
}
|
||||
}
|
@ -18,12 +18,7 @@ public abstract class PythonSourceModifier : PythonParserBaseVisitor<object?>
|
||||
Rewriter = new TokenStreamRewriter(Stream);
|
||||
}
|
||||
|
||||
public virtual string Rewrite()
|
||||
{
|
||||
var tree = Parser.file_input();
|
||||
Visit(tree);
|
||||
return Rewriter.GetText();
|
||||
}
|
||||
public abstract string Rewrite();
|
||||
|
||||
protected string GetOriginalText(ParserRuleContext context)
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user