From dd81fecdd680ea0f9fade13ef8812eb3760c9c05 Mon Sep 17 00:00:00 2001 From: JustFixMe Date: Mon, 8 Apr 2024 22:46:34 +0400 Subject: [PATCH] added Extend extension methods generation --- .../ExtensionsMethodGenerator.cs | 1 + .../ResultAppendExecutor.cs | 8 - .../ResultExtendExecutor.cs | 160 ++++++++++++++++++ .../ResultExtensionsExecutor.cs | 7 + Railway/Error.cs | 4 +- Railway/ErrorJsonConverter.cs | 4 +- Railway/Result.cs | 22 ++- Raliway.Tests/Results/GeneralUsage.cs | 49 +++++- 8 files changed, 235 insertions(+), 20 deletions(-) create mode 100644 Railway.SourceGenerator/ResultExtendExecutor.cs diff --git a/Railway.SourceGenerator/ExtensionsMethodGenerator.cs b/Railway.SourceGenerator/ExtensionsMethodGenerator.cs index b7298be..28e4970 100644 --- a/Railway.SourceGenerator/ExtensionsMethodGenerator.cs +++ b/Railway.SourceGenerator/ExtensionsMethodGenerator.cs @@ -13,6 +13,7 @@ public class ExtensionsMethodGenerator : IIncrementalGenerator new ResultMapExecutor(), new ResultBindExecutor(), new ResultTapExecutor(), + new ResultExtendExecutor(), new ResultTryRecoverExecutor(), new ResultAppendExecutor(), new TryExtensionsExecutor(), diff --git a/Railway.SourceGenerator/ResultAppendExecutor.cs b/Railway.SourceGenerator/ResultAppendExecutor.cs index f0364cc..41679ea 100644 --- a/Railway.SourceGenerator/ResultAppendExecutor.cs +++ b/Railway.SourceGenerator/ResultAppendExecutor.cs @@ -460,12 +460,4 @@ internal sealed class ResultAppendExecutor : ResultExtensionsExecutor } """); } - - internal static string JoinArguments(string arg1, string arg2) => (arg1, arg2) switch - { - ("", "") => "", - (string arg, "") => arg, - ("", string arg) => arg, - _ => $"{arg1}, {arg2}" - }; } diff --git a/Railway.SourceGenerator/ResultExtendExecutor.cs b/Railway.SourceGenerator/ResultExtendExecutor.cs new file mode 100644 index 0000000..55ee7ed --- /dev/null +++ b/Railway.SourceGenerator/ResultExtendExecutor.cs @@ -0,0 +1,160 @@ +using System.Collections.Immutable; +using System.Linq; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Just.Railway.SourceGen; + +internal sealed class ResultExtendExecutor : ResultExtensionsExecutor +{ + protected override string ExtensionType => "Extend"; + + protected override void GenerateMethodsForArgCount(StringBuilder sb, int argCount) + { + if (argCount == 0 || argCount == Constants.MaxResultTupleSize) + { + return; + } + + var templateArgNames = Enumerable.Range(1, argCount) + .Select(i => $"T{i}") + .ToImmutableArray(); + + var expandedTemplateArgNames = templateArgNames.Add("R"); + string resultTypeDef = GenerateResultTypeDef(templateArgNames); + string resultValueExpansion = GenerateResultValueExpansion(templateArgNames); + string resultExpandedTypeDef = GenerateResultTypeDef(expandedTemplateArgNames); + string methodTemplateDecl = GenerateTemplateDecl(expandedTemplateArgNames); + string bindTemplateDecl = GenerateTemplateDecl(templateArgNames.Add("Result")); + + sb.AppendLine($"#region {resultTypeDef}"); + + sb.AppendLine($$""" + [PureAttribute] + [GeneratedCodeAttribute("{{nameof(ResultExtendExecutor)}}", "1.0.0.0")] + public static {{resultExpandedTypeDef}} Extend{{methodTemplateDecl}}(this in {{resultTypeDef}} result, Func{{bindTemplateDecl}} extensionFunc) + { + if (result.State == ResultState.Bottom) + { + throw new ResultNotInitializedException(nameof(result)); + } + else if (result.IsFailure) + { + return result.Error!; + } + + var extension = extensionFunc({{resultValueExpansion}}); + if (extension.State == ResultState.Bottom) + { + throw new ResultNotInitializedException(nameof(extensionFunc)); + } + else if (extension.IsFailure) + { + return extension.Error!; + } + + return Result.Success({{JoinArguments(resultValueExpansion, "extension.Value")}}); + } + """); + + GenerateAsyncMethods("Task", sb, templateArgNames, resultTypeDef, resultValueExpansion); + GenerateAsyncMethods("ValueTask", sb, templateArgNames, resultTypeDef, resultValueExpansion); + + sb.AppendLine("#endregion"); + } + + private static void GenerateAsyncMethods(string taskType, StringBuilder sb, ImmutableArray templateArgNames, string resultTypeDef, string resultValueExpansion) + { + var expandedTemplateArgNames = templateArgNames.Add("R"); + string resultExpandedTypeDef = GenerateResultTypeDef(expandedTemplateArgNames); + string methodTemplateDecl = GenerateTemplateDecl(expandedTemplateArgNames); + string bindTemplateDecl = GenerateTemplateDecl(templateArgNames.Add("Result")); + string asyncActionTemplateDecl = GenerateTemplateDecl(templateArgNames.Add($"{taskType}>")); + + sb.AppendLine($$""" + [PureAttribute] + [GeneratedCodeAttribute("{{nameof(ResultExtendExecutor)}}", "1.0.0.0")] + public static async {{taskType}}<{{resultExpandedTypeDef}}> Extend{{methodTemplateDecl}}(this {{taskType}}<{{resultTypeDef}}> resultTask, Func{{bindTemplateDecl}} extensionFunc) + { + var result = await resultTask.ConfigureAwait(false); + if (result.State == ResultState.Bottom) + { + throw new ResultNotInitializedException(nameof(resultTask)); + } + else if (result.IsFailure) + { + return result.Error!; + } + + var extension = extensionFunc({{resultValueExpansion}}); + if (extension.State == ResultState.Bottom) + { + throw new ResultNotInitializedException(nameof(extensionFunc)); + } + else if (extension.IsFailure) + { + return extension.Error!; + } + + return Result.Success({{JoinArguments(resultValueExpansion, "extension.Value")}}); + } + """); + + sb.AppendLine($$""" + [PureAttribute] + [GeneratedCodeAttribute("{{nameof(ResultExtendExecutor)}}", "1.0.0.0")] + public static async {{taskType}}<{{resultExpandedTypeDef}}> Extend{{methodTemplateDecl}}(this {{resultTypeDef}} result, Func{{asyncActionTemplateDecl}} extensionFunc) + { + if (result.State == ResultState.Bottom) + { + throw new ResultNotInitializedException(nameof(result)); + } + else if (result.IsFailure) + { + return result.Error!; + } + + var extension = await extensionFunc({{resultValueExpansion}}).ConfigureAwait(false); + if (extension.State == ResultState.Bottom) + { + throw new ResultNotInitializedException(nameof(extensionFunc)); + } + else if (extension.IsFailure) + { + return extension.Error!; + } + + return Result.Success({{JoinArguments(resultValueExpansion, "extension.Value")}}); + } + """); + + sb.AppendLine($$""" + [PureAttribute] + [GeneratedCodeAttribute("{{nameof(ResultExtendExecutor)}}", "1.0.0.0")] + public static async {{taskType}}<{{resultExpandedTypeDef}}> Extend{{methodTemplateDecl}}(this {{taskType}}<{{resultTypeDef}}> resultTask, Func{{asyncActionTemplateDecl}} extensionFunc) + { + var result = await resultTask.ConfigureAwait(false); + if (result.State == ResultState.Bottom) + { + throw new ResultNotInitializedException(nameof(resultTask)); + } + else if (result.IsFailure) + { + return result.Error!; + } + + var extension = await extensionFunc({{resultValueExpansion}}).ConfigureAwait(false); + if (extension.State == ResultState.Bottom) + { + throw new ResultNotInitializedException(nameof(extensionFunc)); + } + else if (extension.IsFailure) + { + return extension.Error!; + } + + return Result.Success({{JoinArguments(resultValueExpansion, "extension.Value")}}); + } + """); + } +} diff --git a/Railway.SourceGenerator/ResultExtensionsExecutor.cs b/Railway.SourceGenerator/ResultExtensionsExecutor.cs index b8df66e..c96656d 100644 --- a/Railway.SourceGenerator/ResultExtensionsExecutor.cs +++ b/Railway.SourceGenerator/ResultExtensionsExecutor.cs @@ -50,6 +50,13 @@ internal abstract class ResultExtensionsExecutor : IGeneratorExecutor 1 => $"Result<{string.Join(", ", templateArgNames)}>", _ => $"Result<({string.Join(", ", templateArgNames)})>", }; + protected static string JoinArguments(string arg1, string arg2) => (arg1, arg2) switch + { + ("", "") => "", + (string arg, "") => arg, + ("", string arg) => arg, + _ => $"{arg1}, {arg2}" + }; protected static string GenerateResultValueExpansion(ImmutableArray templateArgNames) { diff --git a/Railway/Error.cs b/Railway/Error.cs index fa723e7..66b8298 100644 --- a/Railway/Error.cs +++ b/Railway/Error.cs @@ -144,7 +144,7 @@ public abstract class Error : IEquatable, IComparable message = Message; } - [Pure] internal virtual Error AccessUnsafe(int position) => this; + [Pure, MethodImpl(MethodImplOptions.AggressiveInlining)] internal virtual Error AccessUnsafe(int position) => this; } [JsonConverter(typeof(ExpectedErrorJsonConverter))] @@ -393,7 +393,7 @@ public sealed class ManyErrors : Error, IEnumerable, IReadOnlyList errors.Add(error); } - [Pure] internal override Error AccessUnsafe(int position) => _errors[position]; + [Pure, MethodImpl(MethodImplOptions.AggressiveInlining)] internal override Error AccessUnsafe(int position) => _errors[position]; } [Serializable] diff --git a/Railway/ErrorJsonConverter.cs b/Railway/ErrorJsonConverter.cs index 5665875..83d6441 100644 --- a/Railway/ErrorJsonConverter.cs +++ b/Railway/ErrorJsonConverter.cs @@ -68,7 +68,7 @@ public sealed class ErrorJsonConverter : JsonConverter if (!(reader.TokenType == JsonTokenType.String)) throw new JsonException("Unable to deserialize Error type."); - + var propvalue = reader.GetString(); if (string.IsNullOrEmpty(propvalue)) break; @@ -84,7 +84,7 @@ public sealed class ErrorJsonConverter : JsonConverter else if (!string.IsNullOrEmpty(propname)) { extensionData ??= ImmutableDictionary.CreateBuilder(); - extensionData.Add(propname, propvalue); + extensionData[propname] = propvalue; } break; diff --git a/Railway/Result.cs b/Railway/Result.cs index 16d32fb..eac49ef 100644 --- a/Railway/Result.cs +++ b/Railway/Result.cs @@ -7,7 +7,13 @@ internal enum ResultState : byte public readonly partial struct Result : IEquatable { - internal SuccessUnit Value => new(); + [SuppressMessage("Performance", "CA1822:Mark members as static", Justification = "Simplified source generation")] + internal SuccessUnit Value + { + [Pure, MethodImpl(MethodImplOptions.AggressiveInlining)] + get => new(); + } + internal readonly Error? Error; internal readonly ResultState State; @@ -36,12 +42,18 @@ public readonly partial struct Result : IEquatable public static Result<(T1, T2, T3, T4, T5)> Success(T1 value1, T2 value2, T3 value3, T4 value4, T5 value5) => new((value1, value2, value3, value4, value5)); + [Pure, MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Result Failure(string error) => Error.New(error ?? throw new ArgumentNullException(nameof(error))); + [Pure, MethodImpl(MethodImplOptions.AggressiveInlining)] public static Result Failure(Error error) => new(error ?? throw new ArgumentNullException(nameof(error))); [Pure, MethodImpl(MethodImplOptions.AggressiveInlining)] public static Result Failure(Exception exception) => new(Error.New(exception) ?? throw new ArgumentNullException(nameof(exception))); + [Pure, MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Result Failure(string error) => Error.New(error ?? throw new ArgumentNullException(nameof(error))); + [Pure, MethodImpl(MethodImplOptions.AggressiveInlining)] public static Result Failure(Error error) => new(error ?? throw new ArgumentNullException(nameof(error))); @@ -66,13 +78,13 @@ public readonly partial struct Result : IEquatable [Pure] public bool IsSuccess => Error is null; [Pure] public bool IsFailure => Error is not null; - [Pure] public bool Success([MaybeNullWhen(false)]out SuccessUnit? u, [MaybeNullWhen(true), NotNullWhen(false)]out Error? error) + [Pure] public bool TryGetValue([MaybeNullWhen(false)]out SuccessUnit? u, [MaybeNullWhen(true), NotNullWhen(false)]out Error? error) { switch (State) { case ResultState.Success: u = new SuccessUnit(); - error = default; + error = null; return true; case ResultState.Error: @@ -161,13 +173,13 @@ public readonly struct Result : IEquatable> [Pure] public bool IsSuccess => State == ResultState.Success; [Pure] public bool IsFailure => State == ResultState.Error; - [Pure] public bool Success([MaybeNullWhen(false)]out T value, [MaybeNullWhen(true), NotNullWhen(false)]out Error? error) + [Pure] public bool TryGetValue([MaybeNullWhen(false)]out T value, [MaybeNullWhen(true), NotNullWhen(false)]out Error? error) { switch (State) { case ResultState.Success: value = Value; - error = default; + error = null; return true; case ResultState.Error: diff --git a/Raliway.Tests/Results/GeneralUsage.cs b/Raliway.Tests/Results/GeneralUsage.cs index 3f2d35e..998b318 100644 --- a/Raliway.Tests/Results/GeneralUsage.cs +++ b/Raliway.Tests/Results/GeneralUsage.cs @@ -22,7 +22,7 @@ public class GeneralUsage return ""; } ); - + Assert.Equal("TEST_1;SOME", result); } @@ -31,7 +31,7 @@ public class GeneralUsage { // Given var error = Error.New("test"); - + // When var result = Result.Success() @@ -94,7 +94,7 @@ public class GeneralUsage { // Given var error = Error.New("test"); - + // When var result = await Result.Success() @@ -169,4 +169,47 @@ public class GeneralUsage Assert.True(result.IsFailure); Assert.Equal(error, result.Error); } + + [Fact] + public void WhenExtendingSuccessWithSuccess_ShouldReturnSuccess() + { + var success = Result.Success(1) + .Append("2"); + + var result = success + .Extend((i, s) => Result.Success($"{i} + {s}")); + + Assert.True(result.IsSuccess); + Assert.Equal((1, "2", "1 + 2"), result.Value); + } + + [Fact] + public void WhenExtendingFailureWithSuccess_ShouldNotEvaluateExtension() + { + var failure = Result.Success(1) + .Append(Result.Failure("failure")); + + var result = failure + .Extend((i, s) => + { + Assert.Fail(); + return Result.Success(""); + }); + + Assert.True(result.IsFailure); + Assert.Equal(Error.New("failure"), result.Error); + } + + [Fact] + public void WhenExtendingSuccessWithFailure_ShouldReturnFailure() + { + var success = Result.Success(1) + .Append("2"); + + var result = success + .Extend((i, s) => Result.Failure("failure")); + + Assert.True(result.IsFailure); + Assert.Equal(Error.New("failure"), result.Error); + } }