// File: mycroforge/MycroForge.CLI/Extensions/CommandExtensions.cs
// (listing metadata: 96 lines, 3.6 KiB, C#)

using System.CommandLine;
using System.CommandLine.Builder;
using System.CommandLine.Invocation;
using System.CommandLine.Parsing;
using System.Reflection;
using MycroForge.CLI.Commands.Attributes;
namespace MycroForge.CLI.Extensions;
/// <summary>
/// Extension methods that build the System.CommandLine pipeline for the root command,
/// including a middleware that evaluates the <c>Requires*</c> attributes declared on commands.
/// </summary>
public static class CommandExtensions
{
    /// <summary>
    /// Builds a parser around <paramref name="rootCommand"/> and invokes it with <paramref name="args"/>.
    /// </summary>
    /// <param name="rootCommand">The root command of the CLI.</param>
    /// <param name="args">The raw command line arguments.</param>
    /// <remarks>
    /// The outcome is reported through <see cref="Environment.ExitCode"/>: it is set to -1 when an
    /// exception reaches the custom exception handler, and otherwise to whatever non-zero exit code
    /// the invocation pipeline returns.
    /// </remarks>
    public static async Task ExecuteAsync(this Commands.MycroForge rootCommand, string[] args)
    {
        var parser = new CommandLineBuilder(rootCommand)
            .AddMiddleware()
            .UseDefaults()
            .UseExceptionHandler((ex, _) =>
            {
                /*
                 * Use a custom ExceptionHandler to prevent the System.CommandLine library from printing the StackTrace by default.
                 */
                Console.WriteLine(ex.Message);

                /*
                 * Set the exit code to a non-zero value to indicate to the shell that the process has failed.
                 */
                Environment.ExitCode = -1;
            })
            .Build();

        var exitCode = await parser.InvokeAsync(args);

        /*
         * Failures that are handled inside the pipeline without throwing (for example argument parse errors)
         * are only visible through the returned exit code, so propagate it to the process.
         * An exit code that was already set by the exception handler above takes precedence.
         */
        if (exitCode != 0 && Environment.ExitCode == 0)
        {
            Environment.ExitCode = exitCode;
        }
    }

    /// <summary>
    /// Registers a middleware that, before the command is invoked, walks the chain of parsed commands
    /// and runs the check of every <c>Requires*</c> attribute found on each command type.
    /// </summary>
    /// <param name="builder">The builder to register the middleware on.</param>
    /// <returns>The same <paramref name="builder"/> instance, to allow call chaining.</returns>
    private static CommandLineBuilder AddMiddleware(this CommandLineBuilder builder)
    {
        builder.AddMiddleware(async (context, next) =>
        {
            var commandChain = context.GetCommandChain();
            var commandText = string.Join(' ', commandChain.Select(cmd => cmd.Name));

            foreach (var command in commandChain)
            {
                /*
                 * Every requirement is evaluated independently, so a command that declares more than one
                 * Requires* attribute has all of them checked instead of only the first match.
                 * NOTE(review): the Require* methods presumably throw when the requirement is not met,
                 * which would then surface through the exception handler - confirm against the attributes.
                 */
                command.GetRequiresFeatureAttribute()?.RequireFeature(commandText);
                command.GetRequiresFileAttribute()?.RequireFile(commandText);
                command.GetRequiresPluginAttribute()?.RequirePluginProject(commandText);
                command.GetRequiresVenvAttribute()?.RequireVenv(commandText);
            }

            await next(context);
        });

        return builder;
    }

    /// <summary>
    /// Collects the commands of the current invocation, ordered from the outermost command to the
    /// command that was actually invoked.
    /// </summary>
    /// <param name="context">The invocation context holding the parse result.</param>
    /// <returns>The chain of commands in the order they appear on the command line.</returns>
    private static List<Command> GetCommandChain(this InvocationContext context)
    {
        var chain = new List<Command>();

        /*
         * The CommandResult property refers to the last command in the chain.
         * So if the command is 'm4g api run' the CommandResult will refer to 'run'.
         */
        SymbolResult? cmd = context.ParseResult.CommandResult;

        // Walk up through the parents until a result that is not a CommandResult (or null) is reached.
        while (cmd is CommandResult result)
        {
            chain.Add(result.Command);
            cmd = cmd.Parent;
        }

        /*
         * Reverse the chain to reflect the actual order of the commands.
         */
        chain.Reverse();

        return chain;
    }

    /// <summary>Returns the <see cref="RequiresFeatureAttribute"/> declared on the command's type, or null.</summary>
    private static RequiresFeatureAttribute? GetRequiresFeatureAttribute(this Command command) =>
        command.GetType().GetCustomAttribute<RequiresFeatureAttribute>();

    /// <summary>Returns the <see cref="RequiresFileAttribute"/> declared on the command's type, or null.</summary>
    private static RequiresFileAttribute? GetRequiresFileAttribute(this Command command) =>
        command.GetType().GetCustomAttribute<RequiresFileAttribute>();

    /// <summary>Returns the <see cref="RequiresPluginAttribute"/> declared on the command's type, or null.</summary>
    private static RequiresPluginAttribute? GetRequiresPluginAttribute(this Command command) =>
        command.GetType().GetCustomAttribute<RequiresPluginAttribute>();

    /// <summary>Returns the <see cref="RequiresVenvAttribute"/> declared on the command's type, or null.</summary>
    private static RequiresVenvAttribute? GetRequiresVenvAttribute(this Command command) =>
        command.GetType().GetCustomAttribute<RequiresVenvAttribute>();
}