fix: 参数值为原始DbParameter时转换类型报错

This commit is contained in:
ly303550688 2021-12-23 00:02:35 +08:00
parent a979f96bff
commit 2737175196
2 changed files with 60 additions and 8 deletions

View File

@ -0,0 +1,50 @@
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Data.SqlClient;
using Xunit;
namespace FreeSql.Tests.Internal
{
public class UtilsTest
{
[Fact]
public void TestGetDbParamtersByObject()
{
var ps = FreeSql.Internal.Utils.
GetDbParamtersByObject<DbParameter>("select @p",
new { p = (DbParameter)new SqlParameter() { ParameterName = "p", Value = "test" } },
"@",
(name, type, value) =>
{
if (value?.Equals(DateTime.MinValue) == true) value = new DateTime(1970, 1, 1);
var ret = new SqlParameter { ParameterName = $"@{name}", Value = value };
return ret;
});
Assert.Single(ps);
Assert.Equal("test", ps[0].Value);
Assert.Equal("p", ps[0].ParameterName);
Assert.Equal(typeof(SqlParameter), ps[0].GetType());
var ps2 = FreeSql.Internal.Utils.
GetDbParamtersByObject<DbParameter>("select @p",
new Dictionary<string, DbParameter> { { "p", (DbParameter)new SqlParameter() { ParameterName = "p", Value = "test" } } },
"@",
(name, type, value) =>
{
if (value?.Equals(DateTime.MinValue) == true) value = new DateTime(1970, 1, 1);
var ret = new SqlParameter { ParameterName = $"@{name}", Value = value };
return ret;
});
Assert.Single(ps2);
Assert.Equal("test", ps2[0].Value);
Assert.Equal("p", ps2[0].ParameterName);
Assert.Equal(typeof(SqlParameter), ps2[0].GetType());
}
}
}

View File

@ -4,6 +4,7 @@ using System;
using System.Collections; using System.Collections;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Data;
using System.Data.Common; using System.Data.Common;
using System.Linq; using System.Linq;
using System.Linq.Expressions; using System.Linq.Expressions;
@ -118,7 +119,7 @@ namespace FreeSql.Internal
var leftBt = colattr.DbType.IndexOf('('); var leftBt = colattr.DbType.IndexOf('(');
colattr.DbType = colattr.DbType.Substring(0, leftBt).ToUpper() + colattr.DbType.Substring(leftBt); colattr.DbType = colattr.DbType.Substring(0, leftBt).ToUpper() + colattr.DbType.Substring(leftBt);
} }
else if(common._orm.Ado.DataType != DataType.ClickHouse) else if (common._orm.Ado.DataType != DataType.ClickHouse)
colattr.DbType = colattr.DbType.ToUpper(); colattr.DbType = colattr.DbType.ToUpper();
if (colattrIsNull == false && colattrIsNullable == true) colattr.DbType = colattr.DbType.Replace("NOT NULL", ""); if (colattrIsNull == false && colattrIsNullable == true) colattr.DbType = colattr.DbType.Replace("NOT NULL", "");
@ -135,7 +136,7 @@ namespace FreeSql.Internal
colattr.DbType = Regex.Replace(colattr.DbType, @"\bNULL\b", "").Trim() + " NOT NULL"; colattr.DbType = Regex.Replace(colattr.DbType, @"\bNULL\b", "").Trim() + " NOT NULL";
} }
if (colattr.IsNullable == true && colattr.DbType.Contains("NOT NULL")) colattr.DbType = colattr.DbType.Replace("NOT NULL", ""); if (colattr.IsNullable == true && colattr.DbType.Contains("NOT NULL")) colattr.DbType = colattr.DbType.Replace("NOT NULL", "");
else if (colattr.IsNullable == true && !colattr.DbType.Contains("Nullable") && common._orm.Ado.DataType == DataType.ClickHouse)colattr.DbType = $"Nullable({colattr.DbType})" ; else if (colattr.IsNullable == true && !colattr.DbType.Contains("Nullable") && common._orm.Ado.DataType == DataType.ClickHouse) colattr.DbType = $"Nullable({colattr.DbType})";
colattr.DbType = Regex.Replace(colattr.DbType, @"\([^\)]+\)", m => colattr.DbType = Regex.Replace(colattr.DbType, @"\([^\)]+\)", m =>
{ {
var tmpLt = Regex.Replace(m.Groups[0].Value, @"\s", ""); var tmpLt = Regex.Replace(m.Groups[0].Value, @"\s", "");
@ -1272,12 +1273,13 @@ namespace FreeSql.Internal
}); });
public static T[] GetDbParamtersByObject<T>(string sql, object obj, string paramPrefix, Func<string, Type, object, T> constructorParamter) public static T[] GetDbParamtersByObject<T>(string sql, object obj, string paramPrefix, Func<string, Type, object, T> constructorParamter)
where T : IDataParameter
{ {
if (string.IsNullOrEmpty(sql) || obj == null) return new T[0]; if (string.IsNullOrEmpty(sql) || obj == null) return new T[0];
var isCheckSql = sql != "*"; var isCheckSql = sql != "*";
var ttype = typeof(T); var ttype = typeof(T);
var type = obj.GetType(); var type = obj.GetType();
if (type == ttype) return new[] { (T)Convert.ChangeType(obj, type) }; if (ttype.IsAssignableFrom(type)) return new[] { (T)obj };
var ret = new List<T>(); var ret = new List<T>();
var dic = obj as IDictionary; var dic = obj as IDictionary;
if (dic != null) if (dic != null)
@ -1288,7 +1290,7 @@ namespace FreeSql.Internal
if (isCheckSql && string.IsNullOrEmpty(paramPrefix) == false && sql.IndexOf($"{paramPrefix}{dbkey}", StringComparison.CurrentCultureIgnoreCase) == -1) continue; if (isCheckSql && string.IsNullOrEmpty(paramPrefix) == false && sql.IndexOf($"{paramPrefix}{dbkey}", StringComparison.CurrentCultureIgnoreCase) == -1) continue;
var val = dic[key]; var val = dic[key];
var valType = val == null ? typeof(string) : val.GetType(); var valType = val == null ? typeof(string) : val.GetType();
if (valType == ttype) ret.Add((T)Convert.ChangeType(val, ttype)); if (ttype.IsAssignableFrom(valType)) ret.Add((T)val);
else ret.Add(constructorParamter(dbkey, valType, val)); else ret.Add(constructorParamter(dbkey, valType, val));
} }
} }
@ -1299,7 +1301,7 @@ namespace FreeSql.Internal
{ {
if (isCheckSql && string.IsNullOrEmpty(paramPrefix) == false && sql.IndexOf($"{paramPrefix}{p.Name}", StringComparison.CurrentCultureIgnoreCase) == -1) continue; if (isCheckSql && string.IsNullOrEmpty(paramPrefix) == false && sql.IndexOf($"{paramPrefix}{p.Name}", StringComparison.CurrentCultureIgnoreCase) == -1) continue;
var pvalue = p.GetValue(obj, null); var pvalue = p.GetValue(obj, null);
if (p.PropertyType == ttype) ret.Add((T)Convert.ChangeType(pvalue, ttype)); if (ttype.IsAssignableFrom(p.PropertyType)) ret.Add((T)pvalue);
else ret.Add(constructorParamter(p.Name, p.PropertyType, pvalue)); else ret.Add(constructorParamter(p.Name, p.PropertyType, pvalue));
} }
} }
@ -1371,7 +1373,7 @@ namespace FreeSql.Internal
this.Value = value; this.Value = value;
this.DataIndex = dataIndex; this.DataIndex = dataIndex;
} }
public static ConstructorInfo Constructor = typeof(RowInfo). GetConstructor(new[] { typeof(object), typeof(int) }); public static ConstructorInfo Constructor = typeof(RowInfo).GetConstructor(new[] { typeof(object), typeof(int) });
public static PropertyInfo PropertyValue = typeof(RowInfo).GetProperty("Value"); public static PropertyInfo PropertyValue = typeof(RowInfo).GetProperty("Value");
public static PropertyInfo PropertyDataIndex = typeof(RowInfo).GetProperty("DataIndex"); public static PropertyInfo PropertyDataIndex = typeof(RowInfo).GetProperty("DataIndex");
} }
@ -1851,7 +1853,7 @@ namespace FreeSql.Internal
static MethodInfo MethodBigIntegerParse = typeof(Utils).GetMethod("ToBigInteger", BindingFlags.Public | BindingFlags.Static, null, new[] { typeof(string) }, null); static MethodInfo MethodBigIntegerParse = typeof(Utils).GetMethod("ToBigInteger", BindingFlags.Public | BindingFlags.Static, null, new[] { typeof(string) }, null);
static PropertyInfo PropertyDateTimeOffsetDateTime = typeof(DateTimeOffset).GetProperty("DateTime", BindingFlags.Instance | BindingFlags.Public); static PropertyInfo PropertyDateTimeOffsetDateTime = typeof(DateTimeOffset).GetProperty("DateTime", BindingFlags.Instance | BindingFlags.Public);
static PropertyInfo PropertyDateTimeTicks = typeof(DateTime).GetProperty("Ticks", BindingFlags.Instance | BindingFlags.Public); static PropertyInfo PropertyDateTimeTicks = typeof(DateTime).GetProperty("Ticks", BindingFlags.Instance | BindingFlags.Public);
static ConstructorInfo CtorDateTimeOffsetArgsTicks = typeof(DateTimeOffset). GetConstructor(new[] { typeof(long), typeof(TimeSpan) }); static ConstructorInfo CtorDateTimeOffsetArgsTicks = typeof(DateTimeOffset).GetConstructor(new[] { typeof(long), typeof(TimeSpan) });
static Encoding DefaultEncoding = Encoding.UTF8; static Encoding DefaultEncoding = Encoding.UTF8;
static MethodInfo MethodEncodingGetBytes = typeof(Encoding).GetMethod("GetBytes", new[] { typeof(string) }); static MethodInfo MethodEncodingGetBytes = typeof(Encoding).GetMethod("GetBytes", new[] { typeof(string) });
static MethodInfo MethodEncodingGetString = typeof(Encoding).GetMethod("GetString", new[] { typeof(byte[]) }); static MethodInfo MethodEncodingGetString = typeof(Encoding).GetMethod("GetString", new[] { typeof(byte[]) });