From f9566af7d100fb9e3964bd890910fca3b7ce2673 Mon Sep 17 00:00:00 2001 From: 28810 <28810@YEXIANGQIN> Date: Fri, 10 Apr 2020 11:58:47 +0800 Subject: [PATCH] update IQueryable provider --- .../FreeSql.Extensions.Linq/ExprHelper.cs | 95 +++++++++++++++++++ .../QueryableProvider.cs | 17 ++-- .../FreeSql.Tests/Queryable/ExprHelperTest.cs | 63 ++++++++++++ 3 files changed, 166 insertions(+), 9 deletions(-) create mode 100644 Extensions/FreeSql.Extensions.Linq/ExprHelper.cs create mode 100644 FreeSql.Tests/FreeSql.Tests/Queryable/ExprHelperTest.cs diff --git a/Extensions/FreeSql.Extensions.Linq/ExprHelper.cs b/Extensions/FreeSql.Extensions.Linq/ExprHelper.cs new file mode 100644 index 00000000..75d2e5aa --- /dev/null +++ b/Extensions/FreeSql.Extensions.Linq/ExprHelper.cs @@ -0,0 +1,95 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Linq.Expressions; +using System.Linq; +using System.Reflection; + +namespace FreeSql.Extensions.Linq +{ + public static class ExprHelper + { + public static object GetConstExprValue(this Expression exp) + { + if (exp.IsParameter()) return null; + + var expStack = new Stack(); + var exp2 = exp; + while (true) + { + switch (exp2?.NodeType) + { + case ExpressionType.Constant: + expStack.Push(exp2); + break; + case ExpressionType.MemberAccess: + expStack.Push(exp2); + exp2 = (exp2 as MemberExpression).Expression; + if (exp2 == null) break; + continue; + case ExpressionType.Call: + return Expression.Lambda(exp).Compile().DynamicInvoke(); + case ExpressionType.TypeAs: + case ExpressionType.Convert: + var oper2 = (exp2 as UnaryExpression).Operand; + if (oper2.NodeType == ExpressionType.Parameter) + { + var oper2Parm = oper2 as ParameterExpression; + expStack.Push(exp2.Type.IsAbstract || exp2.Type.IsInterface ? oper2Parm : Expression.Parameter(exp2.Type, oper2Parm.Name)); + } + else + expStack.Push(oper2); + break; + } + break; + } + object firstValue = null; + switch (expStack.First().NodeType) + { + case ExpressionType.Constant: + var expStackFirst = expStack.Pop() as ConstantExpression; + firstValue = expStackFirst?.Value; + break; + case ExpressionType.MemberAccess: + var expStackFirstMem = expStack.First() as MemberExpression; + if (expStackFirstMem.Expression?.NodeType == ExpressionType.Constant) + firstValue = (expStackFirstMem.Expression as ConstantExpression)?.Value; + else + return Expression.Lambda(exp).Compile().DynamicInvoke(); + break; + } + while (expStack.Any()) + { + var expStackItem = expStack.Pop(); + switch (expStackItem.NodeType) + { + case ExpressionType.MemberAccess: + var memExp = expStackItem as MemberExpression; + if (memExp.Member.MemberType == MemberTypes.Property) + firstValue = ((PropertyInfo)memExp.Member).GetValue(firstValue, null); + else if (memExp.Member.MemberType == MemberTypes.Field) + firstValue = ((FieldInfo)memExp.Member).GetValue(firstValue); + break; + } + } + return firstValue; + } + + public static bool IsParameter(this Expression exp) + { + var test = new TestParameterExpressionVisitor(); + test.Visit(exp); + return test.Result; + } + internal class TestParameterExpressionVisitor : ExpressionVisitor + { + public bool Result { get; private set; } + + protected override Expression VisitParameter(ParameterExpression node) + { + if (!Result) Result = true; + return node; + } + } + } +} diff --git a/Extensions/FreeSql.Extensions.Linq/QueryableProvider.cs b/Extensions/FreeSql.Extensions.Linq/QueryableProvider.cs index 892e1cf8..b65147b1 100644 --- a/Extensions/FreeSql.Extensions.Linq/QueryableProvider.cs +++ b/Extensions/FreeSql.Extensions.Linq/QueryableProvider.cs @@ -81,7 +81,7 @@ namespace FreeSql.Extensions.Linq { callExp = stackCallExps.Pop(); TResult throwCallExp(string message) => throw new Exception($"FreeSql Queryable 解析出错,执行的方法 {callExp.Method.Name} {message}"); - if (callExp.Method.DeclaringType != typeof(Queryable)) return throwCallExp($"必须属于 System.Linq.Enumerable"); + if (callExp.Method.DeclaringType != typeof(Queryable)) return throwCallExp($"必须属于 System.Linq.Queryable"); TResult tplMaxMinAvgSum(string method) { if (callExp.Arguments.Count == 2) @@ -105,7 +105,7 @@ namespace FreeSql.Extensions.Linq switch (callExp.Method.Name) { case "Any": - if (callExp.Arguments.Count == 2) _select.Where((Expression>)(callExp.Arguments[1] as UnaryExpression)?.Operand); + if (callExp.Arguments.Count == 2) _select.InternalWhere(callExp.Arguments[1]); return (TResult)(object)_select.Any(); case "AsQueryable": break; @@ -120,14 +120,14 @@ namespace FreeSql.Extensions.Linq case "Contains": if (callExp.Arguments.Count == 2) { - var dywhere = (callExp.Arguments[1] as ConstantExpression)?.Value as TSource; + var dywhere = callExp.Arguments[1].GetConstExprValue(); if (dywhere == null) return throwCallExp($" 参数值不能为 null"); _select.WhereDynamic(dywhere); return (TResult)(object)_select.Any(); } return throwCallExp($" 不支持 {callExp.Arguments.Count}个参数的方法"); case "Count": - if (callExp.Arguments.Count == 2) _select.Where((Expression>)(callExp.Arguments[1] as UnaryExpression)?.Operand); + if (callExp.Arguments.Count == 2) _select.InternalWhere(callExp.Arguments[1]); return (TResult)Utils.GetDataReaderValue(typeof(TResult), _select.Count()); case "Distinct": @@ -140,7 +140,7 @@ namespace FreeSql.Extensions.Linq case "ElementAt": case "ElementAtOrDefault": - _select.Offset((int)(callExp.Arguments[1] as ConstantExpression)?.Value); + _select.Offset((int)callExp.Arguments[1].GetConstExprValue()); _select.Limit(1); isfirst = true; break; @@ -148,7 +148,7 @@ namespace FreeSql.Extensions.Linq case "FirstOrDefault": case "Single": case "SingleOrDefault": - if (callExp.Arguments.Count == 2) _select.Where((Expression>)(callExp.Arguments[1] as UnaryExpression)?.Operand); + if (callExp.Arguments.Count == 2) _select.InternalWhere(callExp.Arguments[1]); _select.Limit(1); isfirst = true; break; @@ -177,10 +177,10 @@ namespace FreeSql.Extensions.Linq return throwCallExp(" 不支持"); case "Skip": - _select.Offset((int)(callExp.Arguments[1] as ConstantExpression)?.Value); + _select.Offset((int)callExp.Arguments[1].GetConstExprValue()); break; case "Take": - _select.Limit((int)(callExp.Arguments[1] as ConstantExpression)?.Value); + _select.Limit((int)callExp.Arguments[1].GetConstExprValue()); break; case "ToList": @@ -188,7 +188,6 @@ namespace FreeSql.Extensions.Linq return (TResult)(object)_select.ToList(); return throwCallExp(" 不支持"); - case "Select": var selectParam = (callExp.Arguments[1] as UnaryExpression)?.Operand as LambdaExpression; if (selectParam.Parameters.Count == 1) diff --git a/FreeSql.Tests/FreeSql.Tests/Queryable/ExprHelperTest.cs b/FreeSql.Tests/FreeSql.Tests/Queryable/ExprHelperTest.cs new file mode 100644 index 00000000..a302b81f --- /dev/null +++ b/FreeSql.Tests/FreeSql.Tests/Queryable/ExprHelperTest.cs @@ -0,0 +1,63 @@ +using FreeSql.DataAnnotations; +using FreeSql; +using System; +using System.Collections.Generic; +using Xunit; +using System.Linq; +using Newtonsoft.Json.Linq; +using NpgsqlTypes; +using Npgsql.LegacyPostgis; +using System.Linq.Expressions; +using System.Threading.Tasks; +using System.ComponentModel.DataAnnotations; +using System.Threading; +using System.Data.SqlClient; +using kwlib; +using System.Diagnostics; +using System.IO; +using System.Text; +using FreeSql.Extensions.Linq; + +namespace FreeSql.Tests.Linq +{ + public class ExprHelperTest + { + + [Fact] + public void GetConstExprValue() + { + Assert.Equal(-1, ExprHelper.GetConstExprValue(Expression.Constant(-1))); + Assert.Equal(-2, ExprHelper.GetConstExprValue(Expression.Constant(-2))); + Assert.Equal(0, ExprHelper.GetConstExprValue(Expression.Constant(0))); + Assert.Equal(1, ExprHelper.GetConstExprValue(Expression.Constant(1))); + Assert.Equal(2, ExprHelper.GetConstExprValue(Expression.Constant(2))); + + var arr = new[] { -1, -2, 0, 1, 2 }; + for (var a = 0; a < arr.Length; a++) + { + Assert.Equal(arr[a], ExprHelper.GetConstExprValue(Expression.Constant(arr[a]))); + } + + var arritems = new[] + { + new ArrItem { Prop = -1, Field = -1 }, + new ArrItem { Prop = -2, Field = -2 }, + new ArrItem { Prop = 0, Field = 0 }, + new ArrItem { Prop = 1, Field = 1 }, + new ArrItem { Prop = 2, Field = 2 }, + }; + for (var a = 0; a < arr.Length; a++) + { + Assert.Equal(arritems[a].Prop, ExprHelper.GetConstExprValue(Expression.Constant(arritems[a].Prop))); + Assert.Equal(arritems[a].Field, ExprHelper.GetConstExprValue(Expression.Constant(arritems[a].Field))); + } + } + + class ArrItem + { + public int Prop { get; set; } + public int Field { get; set; } + } + } + +}