diff --git a/src/Spectre.Console.Analyzer/Fixes/CodeActions/SwitchToAnsiConsoleAction.cs b/src/Spectre.Console.Analyzer/Fixes/CodeActions/SwitchToAnsiConsoleAction.cs
index c2930fb..d6802bc 100644
--- a/src/Spectre.Console.Analyzer/Fixes/CodeActions/SwitchToAnsiConsoleAction.cs
+++ b/src/Spectre.Console.Analyzer/Fixes/CodeActions/SwitchToAnsiConsoleAction.cs
@@ -1,4 +1,5 @@
-using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
+using Microsoft.CodeAnalysis.Editing;
+using Microsoft.CodeAnalysis.Simplification;
namespace Spectre.Console.Analyzer.CodeActions;
@@ -31,90 +32,151 @@ public class SwitchToAnsiConsoleAction : CodeAction
///
protected override async Task GetChangedDocumentAsync(CancellationToken cancellationToken)
- {
- var originalCaller = ((MemberAccessExpressionSyntax)_originalInvocation.Expression).Name.ToString();
-
- var syntaxTree = await _document.GetSyntaxTreeAsync(cancellationToken).ConfigureAwait(false);
- if (syntaxTree == null)
- {
- return _document;
- }
-
- var root = (CompilationUnitSyntax)await syntaxTree.GetRootAsync(cancellationToken).ConfigureAwait(false);
-
- // If there is an ansiConsole passed into the method then we'll use it.
- // otherwise we'll check for a field level instance.
- // if neither of those exist we'll fall back to the static param.
- var ansiConsoleParameterDeclaration = GetAnsiConsoleParameterDeclaration();
- var ansiConsoleFieldIdentifier = GetAnsiConsoleFieldDeclaration();
- var ansiConsoleIdentifier = ansiConsoleParameterDeclaration ??
- ansiConsoleFieldIdentifier ??
- Constants.StaticInstance;
-
- // Replace the System.Console call with a call to the identifier above.
- var newRoot = root.ReplaceNode(
- _originalInvocation,
- GetImportedSpectreCall(originalCaller, ansiConsoleIdentifier));
-
- // If we are calling the static instance and Spectre isn't imported yet we should do so.
- if (ansiConsoleIdentifier == Constants.StaticInstance && root.Usings.ToList().All(i => i.Name.ToString() != Constants.SpectreConsole))
- {
- newRoot = newRoot.AddUsings(Syntax.SpectreUsing);
- }
-
- return _document.WithSyntaxRoot(newRoot);
- }
-
- private string? GetAnsiConsoleParameterDeclaration()
- {
- return _originalInvocation
- .Ancestors().OfType()
- .FirstOrDefault()
- ?.ParameterList.Parameters
- .FirstOrDefault(i => i.Type?.NormalizeWhitespace()?.ToString() == "IAnsiConsole")
- ?.Identifier.Text;
- }
-
- private string? GetAnsiConsoleFieldDeclaration()
- {
- // let's look to see if our call is in a static method.
- // if so we'll only want to look for static IAnsiConsoles
- // and vice-versa if we aren't.
- // If there is no parent method, the SyntaxNode should be in
- // a top-level statement, so there is no field anyway.
- var isStatic = _originalInvocation
- .Ancestors()
- .OfType()
- .FirstOrDefault()
- ?.Modifiers.Any(i => i.IsKind(SyntaxKind.StaticKeyword));
+ {
+ var editor = await DocumentEditor.CreateAsync(_document, cancellationToken).ConfigureAwait(false);
+ var compilation = editor.SemanticModel.Compilation;
- if (isStatic == null)
+ var operation = editor.SemanticModel.GetOperation(_originalInvocation, cancellationToken) as IInvocationOperation;
+ if (operation == null)
{
- return null;
- }
-
- return _originalInvocation
- .Ancestors().OfType()
- .First()
- .Members
- .OfType()
- .FirstOrDefault(i =>
- i.Declaration.Type.NormalizeWhitespace().ToString() == "IAnsiConsole" &&
- (!isStatic.GetValueOrDefault() ^ i.Modifiers.Any(modifier => modifier.IsKind(SyntaxKind.StaticKeyword))))
- ?.Declaration.Variables.First().Identifier.Text;
- }
-
- private ExpressionSyntax GetImportedSpectreCall(string originalCaller, string ansiConsoleIdentifier)
- {
- return ExpressionStatement(
- InvocationExpression(
- MemberAccessExpression(
- SyntaxKind.SimpleMemberAccessExpression,
- IdentifierName(ansiConsoleIdentifier),
- IdentifierName(originalCaller)))
- .WithArgumentList(_originalInvocation.ArgumentList)
- .WithTrailingTrivia(_originalInvocation.GetTrailingTrivia())
- .WithLeadingTrivia(_originalInvocation.GetLeadingTrivia()))
- .Expression;
+ return _document;
+ }
+
+ // If there is an IAnsiConsole passed into the method then we'll use it.
+ // otherwise we'll check for a field level instance.
+ // if neither of those exist we'll fall back to the static param.
+ var spectreConsoleSymbol = compilation.GetTypeByMetadataName("Spectre.Console.AnsiConsole");
+ var iansiConsoleSymbol = compilation.GetTypeByMetadataName("Spectre.Console.IAnsiConsole");
+
+ ISymbol? accessibleConsoleSymbol = spectreConsoleSymbol;
+ if (iansiConsoleSymbol != null)
+ {
+ var isInStaticContext = IsInStaticContext(operation, cancellationToken, out var parentStaticMemberStartPosition);
+
+ foreach (var symbol in editor.SemanticModel.LookupSymbols(operation.Syntax.GetLocation().SourceSpan.Start))
+ {
+ // LookupSymbols check the accessibility of the symbol, but it can
+ // suggest instance members when the current context is static.
+ var symbolType = symbol switch
+ {
+ IParameterSymbol parameter => parameter.Type,
+ IFieldSymbol field when !isInStaticContext || field.IsStatic => field.Type,
+ IPropertySymbol { GetMethod: not null } property when !isInStaticContext || property.IsStatic => property.Type,
+ ILocalSymbol local => local.Type,
+ _ => null,
+ };
+
+ // Locals can be returned even if there are not valid in the current context. For instance,
+ // it can return locals declared after the current location. Or it can return locals that
+ // should not be accessible in a static local function.
+ //
+ // void Sample()
+ // {
+ // int local = 0;
+ // static void LocalFunction() => local; <-- local is invalid here but LookupSymbols suggests it
+ // }
+ if (symbol.Kind is SymbolKind.Local)
+ {
+ var localPosition = symbol.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax(cancellationToken).GetLocation().SourceSpan.Start;
+
+ // The local is not part of the source tree
+ if (localPosition == null)
+ {
+ break;
+ }
+
+ // The local is declared after the current expression
+ if (localPosition > _originalInvocation.Span.Start)
+ {
+ break;
+ }
+
+ // The local is declared outside the static local function
+ if (isInStaticContext && localPosition < parentStaticMemberStartPosition)
+ {
+ break;
+ }
+ }
+
+ if (IsOrImplementSymbol(symbolType, iansiConsoleSymbol))
+ {
+ accessibleConsoleSymbol = symbol;
+ break;
+ }
+ }
+ }
+
+ if (accessibleConsoleSymbol == null)
+ {
+ return _document;
+ }
+
+ // Replace the original invocation
+ var generator = editor.Generator;
+ var consoleExpression = accessibleConsoleSymbol switch
+ {
+ ITypeSymbol typeSymbol => generator.TypeExpression(typeSymbol, addImport: true).WithAdditionalAnnotations(Simplifier.AddImportsAnnotation),
+ _ => generator.IdentifierName(accessibleConsoleSymbol.Name),
+ };
+
+ var newExpression = generator.InvocationExpression(generator.MemberAccessExpression(consoleExpression, operation.TargetMethod.Name), _originalInvocation.ArgumentList.Arguments)
+ .WithLeadingTrivia(_originalInvocation.GetLeadingTrivia())
+ .WithTrailingTrivia(_originalInvocation.GetTrailingTrivia());
+
+ editor.ReplaceNode(_originalInvocation, newExpression);
+
+ return editor.GetChangedDocument();
+ }
+
+ private static bool IsOrImplementSymbol(ITypeSymbol? symbol, ITypeSymbol interfaceSymbol)
+ {
+ if (symbol == null)
+ {
+ return false;
+ }
+
+ if (SymbolEqualityComparer.Default.Equals(symbol, interfaceSymbol))
+ {
+ return true;
+ }
+
+ foreach (var iface in symbol.AllInterfaces)
+ {
+ if (SymbolEqualityComparer.Default.Equals(iface, interfaceSymbol))
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ private static bool IsInStaticContext(IOperation operation, CancellationToken cancellationToken, out int parentStaticMemberStartPosition)
+ {
+ // Local functions can be nested, and an instance local function can be declared
+ // in a static local function. So, you need to continue to check ancestors when a
+ // local function is not static.
+ foreach (var member in operation.Syntax.Ancestors())
+ {
+ if (member is LocalFunctionStatementSyntax localFunction)
+ {
+ var symbol = operation.SemanticModel!.GetDeclaredSymbol(localFunction, cancellationToken);
+ if (symbol != null && symbol.IsStatic)
+ {
+ parentStaticMemberStartPosition = localFunction.GetLocation().SourceSpan.Start;
+ return true;
+ }
+ }
+ else if (member is MethodDeclarationSyntax methodDeclaration)
+ {
+ parentStaticMemberStartPosition = methodDeclaration.GetLocation().SourceSpan.Start;
+
+ var symbol = operation.SemanticModel!.GetDeclaredSymbol(methodDeclaration, cancellationToken);
+ return symbol != null && symbol.IsStatic;
+ }
+ }
+
+ parentStaticMemberStartPosition = -1;
+ return false;
}
}
\ No newline at end of file
diff --git a/test/Spectre.Console.Analyzer.Tests/Unit/Fixes/UseSpectreInsteadOfSystemConsoleFixTests.cs b/test/Spectre.Console.Analyzer.Tests/Unit/Fixes/UseSpectreInsteadOfSystemConsoleFixTests.cs
index 1e0f615..d6b0aaf 100644
--- a/test/Spectre.Console.Analyzer.Tests/Unit/Fixes/UseSpectreInsteadOfSystemConsoleFixTests.cs
+++ b/test/Spectre.Console.Analyzer.Tests/Unit/Fixes/UseSpectreInsteadOfSystemConsoleFixTests.cs
@@ -104,6 +104,74 @@ class TestClass
.ConfigureAwait(false);
}
+ [Fact]
+ public async Task SystemConsole_replaced_with_local_variable_AnsiConsole()
+ {
+ const string Source = @"
+using System;
+using Spectre.Console;
+
+class TestClass
+{
+ void TestMethod()
+ {
+ IAnsiConsole ansiConsole = null;
+ Console.WriteLine(""Hello, World"");
+ }
+}";
+
+ const string FixedSource = @"
+using System;
+using Spectre.Console;
+
+class TestClass
+{
+ void TestMethod()
+ {
+ IAnsiConsole ansiConsole = null;
+ ansiConsole.WriteLine(""Hello, World"");
+ }
+}";
+
+ await SpectreAnalyzerVerifier
+ .VerifyCodeFixAsync(Source, _expectedDiagnostic.WithLocation(10, 9), FixedSource)
+ .ConfigureAwait(false);
+ }
+
+ [Fact]
+ public async Task SystemConsole_not_replaced_with_local_variable_declared_after_the_call()
+ {
+ const string Source = @"
+using System;
+using Spectre.Console;
+
+class TestClass
+{
+ void TestMethod()
+ {
+ Console.WriteLine(""Hello, World"");
+ IAnsiConsole ansiConsole;
+ }
+}";
+
+ const string FixedSource = @"
+using System;
+using Spectre.Console;
+
+class TestClass
+{
+ void TestMethod()
+ {
+ AnsiConsole.WriteLine(""Hello, World"");
+ IAnsiConsole ansiConsole;
+ }
+}";
+
+ await SpectreAnalyzerVerifier
+ .VerifyCodeFixAsync(Source, _expectedDiagnostic.WithLocation(9, 9), FixedSource)
+ .ConfigureAwait(false);
+ }
+
[Fact]
public async Task SystemConsole_replaced_with_static_field_AnsiConsole()
{
@@ -140,6 +208,108 @@ class TestClass
.ConfigureAwait(false);
}
+ [Fact]
+ public async Task SystemConsole_replaced_with_AnsiConsole_when_field_is_not_static()
+ {
+ const string Source = @"
+using System;
+using Spectre.Console;
+
+class TestClass
+{
+ IAnsiConsole _ansiConsole;
+
+ static void TestMethod()
+ {
+ Console.WriteLine(""Hello, World"");
+ }
+}";
+
+ const string FixedSource = @"
+using System;
+using Spectre.Console;
+
+class TestClass
+{
+ IAnsiConsole _ansiConsole;
+
+ static void TestMethod()
+ {
+ AnsiConsole.WriteLine(""Hello, World"");
+ }
+}";
+
+ await SpectreAnalyzerVerifier
+ .VerifyCodeFixAsync(Source, _expectedDiagnostic.WithLocation(11, 9), FixedSource)
+ .ConfigureAwait(false);
+ }
+
+ [Fact]
+ public async Task SystemConsole_replaced_with_AnsiConsole_from_local_function_parameter()
+ {
+ const string Source = @"
+using System;
+using Spectre.Console;
+
+class TestClass
+{
+ static void TestMethod()
+ {
+ static void LocalFunction(IAnsiConsole ansiConsole) => Console.WriteLine(""Hello, World"");
+ }
+}";
+
+ const string FixedSource = @"
+using System;
+using Spectre.Console;
+
+class TestClass
+{
+ static void TestMethod()
+ {
+ static void LocalFunction(IAnsiConsole ansiConsole) => ansiConsole.WriteLine(""Hello, World"");
+ }
+}";
+
+ await SpectreAnalyzerVerifier
+ .VerifyCodeFixAsync(Source, _expectedDiagnostic.WithLocation(9, 64), FixedSource)
+ .ConfigureAwait(false);
+ }
+
+ [Fact]
+ public async Task SystemConsole_do_not_use_variable_from_parent_method_in_static_local_function()
+ {
+ const string Source = @"
+using System;
+using Spectre.Console;
+
+class TestClass
+{
+ static void TestMethod()
+ {
+ IAnsiConsole ansiConsole = null;
+ static void LocalFunction() => Console.WriteLine(""Hello, World"");
+ }
+}";
+
+ const string FixedSource = @"
+using System;
+using Spectre.Console;
+
+class TestClass
+{
+ static void TestMethod()
+ {
+ IAnsiConsole ansiConsole = null;
+ static void LocalFunction() => AnsiConsole.WriteLine(""Hello, World"");
+ }
+}";
+
+ await SpectreAnalyzerVerifier
+ .VerifyCodeFixAsync(Source, _expectedDiagnostic.WithLocation(10, 40), FixedSource)
+ .ConfigureAwait(false);
+ }
+
[Fact]
public async Task SystemConsole_replaced_with_AnsiConsole_in_top_level_statements()
{