diff --git a/FreeSql.Tests/FreeSql.Tests/Internal/UtilsTest.cs b/FreeSql.Tests/FreeSql.Tests/Internal/UtilsTest.cs new file mode 100644 index 00000000..250bae5c --- /dev/null +++ b/FreeSql.Tests/FreeSql.Tests/Internal/UtilsTest.cs @@ -0,0 +1,50 @@ +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient; +using Xunit; + +namespace FreeSql.Tests.Internal +{ + + public class UtilsTest + { + [Fact] + public void TestGetDbParamtersByObject() + { + var ps = FreeSql.Internal.Utils. + GetDbParamtersByObject("select @p", + new { p = (DbParameter)new SqlParameter() { ParameterName = "p", Value = "test" } }, + "@", + (name, type, value) => + { + if (value?.Equals(DateTime.MinValue) == true) value = new DateTime(1970, 1, 1); + var ret = new SqlParameter { ParameterName = $"@{name}", Value = value }; + return ret; + }); + Assert.Single(ps); + Assert.Equal("test", ps[0].Value); + Assert.Equal("p", ps[0].ParameterName); + Assert.Equal(typeof(SqlParameter), ps[0].GetType()); + + + var ps2 = FreeSql.Internal.Utils. + GetDbParamtersByObject("select @p", + new Dictionary { { "p", (DbParameter)new SqlParameter() { ParameterName = "p", Value = "test" } } }, + "@", + (name, type, value) => + { + if (value?.Equals(DateTime.MinValue) == true) value = new DateTime(1970, 1, 1); + var ret = new SqlParameter { ParameterName = $"@{name}", Value = value }; + return ret; + }); + Assert.Single(ps2); + Assert.Equal("test", ps2[0].Value); + Assert.Equal("p", ps2[0].ParameterName); + Assert.Equal(typeof(SqlParameter), ps2[0].GetType()); + } + } +} diff --git a/FreeSql/Internal/UtilsExpressionTree.cs b/FreeSql/Internal/UtilsExpressionTree.cs index a4b43ca4..c0b0422f 100644 --- a/FreeSql/Internal/UtilsExpressionTree.cs +++ b/FreeSql/Internal/UtilsExpressionTree.cs @@ -4,6 +4,7 @@ using System; using System.Collections; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Data; using System.Data.Common; using System.Linq; using System.Linq.Expressions; @@ -118,7 +119,7 @@ namespace FreeSql.Internal var leftBt = colattr.DbType.IndexOf('('); colattr.DbType = colattr.DbType.Substring(0, leftBt).ToUpper() + colattr.DbType.Substring(leftBt); } - else if(common._orm.Ado.DataType != DataType.ClickHouse) + else if (common._orm.Ado.DataType != DataType.ClickHouse) colattr.DbType = colattr.DbType.ToUpper(); if (colattrIsNull == false && colattrIsNullable == true) colattr.DbType = colattr.DbType.Replace("NOT NULL", ""); @@ -135,7 +136,7 @@ namespace FreeSql.Internal colattr.DbType = Regex.Replace(colattr.DbType, @"\bNULL\b", "").Trim() + " NOT NULL"; } if (colattr.IsNullable == true && colattr.DbType.Contains("NOT NULL")) colattr.DbType = colattr.DbType.Replace("NOT NULL", ""); - else if (colattr.IsNullable == true && !colattr.DbType.Contains("Nullable") && common._orm.Ado.DataType == DataType.ClickHouse)colattr.DbType = $"Nullable({colattr.DbType})" ; + else if (colattr.IsNullable == true && !colattr.DbType.Contains("Nullable") && common._orm.Ado.DataType == DataType.ClickHouse) colattr.DbType = $"Nullable({colattr.DbType})"; colattr.DbType = Regex.Replace(colattr.DbType, @"\([^\)]+\)", m => { var tmpLt = Regex.Replace(m.Groups[0].Value, @"\s", ""); @@ -1272,12 +1273,13 @@ namespace FreeSql.Internal }); public static T[] GetDbParamtersByObject(string sql, object obj, string paramPrefix, Func constructorParamter) + where T : IDataParameter { if (string.IsNullOrEmpty(sql) || obj == null) return new T[0]; var isCheckSql = sql != "*"; var ttype = typeof(T); var type = obj.GetType(); - if (type == ttype) return new[] { (T)Convert.ChangeType(obj, type) }; + if (ttype.IsAssignableFrom(type)) return new[] { (T)obj }; var ret = new List(); var dic = obj as IDictionary; if (dic != null) @@ -1288,7 +1290,7 @@ namespace FreeSql.Internal if (isCheckSql && string.IsNullOrEmpty(paramPrefix) == false && sql.IndexOf($"{paramPrefix}{dbkey}", StringComparison.CurrentCultureIgnoreCase) == -1) continue; var val = dic[key]; var valType = val == null ? typeof(string) : val.GetType(); - if (valType == ttype) ret.Add((T)Convert.ChangeType(val, ttype)); + if (ttype.IsAssignableFrom(valType)) ret.Add((T)val); else ret.Add(constructorParamter(dbkey, valType, val)); } } @@ -1299,7 +1301,7 @@ namespace FreeSql.Internal { if (isCheckSql && string.IsNullOrEmpty(paramPrefix) == false && sql.IndexOf($"{paramPrefix}{p.Name}", StringComparison.CurrentCultureIgnoreCase) == -1) continue; var pvalue = p.GetValue(obj, null); - if (p.PropertyType == ttype) ret.Add((T)Convert.ChangeType(pvalue, ttype)); + if (ttype.IsAssignableFrom(p.PropertyType)) ret.Add((T)pvalue); else ret.Add(constructorParamter(p.Name, p.PropertyType, pvalue)); } } @@ -1371,7 +1373,7 @@ namespace FreeSql.Internal this.Value = value; this.DataIndex = dataIndex; } - public static ConstructorInfo Constructor = typeof(RowInfo). GetConstructor(new[] { typeof(object), typeof(int) }); + public static ConstructorInfo Constructor = typeof(RowInfo).GetConstructor(new[] { typeof(object), typeof(int) }); public static PropertyInfo PropertyValue = typeof(RowInfo).GetProperty("Value"); public static PropertyInfo PropertyDataIndex = typeof(RowInfo).GetProperty("DataIndex"); } @@ -1851,7 +1853,7 @@ namespace FreeSql.Internal static MethodInfo MethodBigIntegerParse = typeof(Utils).GetMethod("ToBigInteger", BindingFlags.Public | BindingFlags.Static, null, new[] { typeof(string) }, null); static PropertyInfo PropertyDateTimeOffsetDateTime = typeof(DateTimeOffset).GetProperty("DateTime", BindingFlags.Instance | BindingFlags.Public); static PropertyInfo PropertyDateTimeTicks = typeof(DateTime).GetProperty("Ticks", BindingFlags.Instance | BindingFlags.Public); - static ConstructorInfo CtorDateTimeOffsetArgsTicks = typeof(DateTimeOffset). GetConstructor(new[] { typeof(long), typeof(TimeSpan) }); + static ConstructorInfo CtorDateTimeOffsetArgsTicks = typeof(DateTimeOffset).GetConstructor(new[] { typeof(long), typeof(TimeSpan) }); static Encoding DefaultEncoding = Encoding.UTF8; static MethodInfo MethodEncodingGetBytes = typeof(Encoding).GetMethod("GetBytes", new[] { typeof(string) }); static MethodInfo MethodEncodingGetString = typeof(Encoding).GetMethod("GetString", new[] { typeof(byte[]) }); @@ -1870,7 +1872,7 @@ namespace FreeSql.Internal { if (type.IsArray) { - switch (type.FullName) + switch (type.FullName) { case "System.Byte[]": return Expression.IfThenElse(