Skip to content

Commit

Permalink
FIX | Removing BinaryFormatter from NetFx (dotnet#869)
Browse files Browse the repository at this point in the history
* Removing BinaryFormatter from NetFx

* review comments

* fix version typo

* remove extra line

* Reverted SqlException Test

* review comments

* Review comment

* Desrialize

* addressing review comments

* Fix exception in deserialization (#1)

* review comments

* add extra line to the end of strings designer

* end of line

Co-authored-by: jJRahnama <jrahnama@simba.com>
Co-authored-by: Karina Zhou <v-jizho2@microsoft.com>
  • Loading branch information
3 people committed Feb 18, 2021
1 parent b5d7bb6 commit 25cde90
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 49 deletions.
Expand Up @@ -10,7 +10,6 @@
using System.Runtime.CompilerServices;
using System.Runtime.Remoting;
using System.Runtime.Serialization;
using System.Runtime.Serialization.Formatters.Binary;
using System.Runtime.Versioning;
using System.Security.Permissions;
using System.Text;
Expand Down Expand Up @@ -241,29 +240,39 @@ private static void InvokeCallback(object eventContextPair)
// END EventContextPair private class.
// ----------------------------------------

// ----------------------------------------
// Private class for restricting allowed types from deserialization.
// ----------------------------------------

private class SqlDependencyProcessDispatcherSerializationBinder : SerializationBinder
//-----------------------------------------------
// Private Class to add ObjRef as DataContract
//-----------------------------------------------
[SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.RemotingConfiguration)]
[DataContract]
private class SqlClientObjRef
{
public override Type BindToType(string assemblyName, string typeName)
[DataMember]
private static ObjRef s_sqlObjRef;
internal static IRemotingTypeInfo _typeInfo;

private SqlClientObjRef() { }

public SqlClientObjRef(SqlDependencyProcessDispatcher dispatcher) : base()
{
// Deserializing an unexpected type can inject objects with malicious side effects.
// If the type is unexpected, throw an exception to stop deserialization.
if (typeName == nameof(SqlDependencyProcessDispatcher))
{
return typeof(SqlDependencyProcessDispatcher);
}
else
{
throw new ArgumentException("Unexpected type", nameof(typeName));
}
s_sqlObjRef = RemotingServices.Marshal(dispatcher);
_typeInfo = s_sqlObjRef.TypeInfo;
}

internal static bool CanCastToSqlDependencyProcessDispatcher()
{
return _typeInfo.CanCastTo(typeof(SqlDependencyProcessDispatcher), s_sqlObjRef);
}

internal ObjRef GetObjRef()
{
return s_sqlObjRef;
}

}
// ----------------------------------------
// END SqlDependencyProcessDispatcherSerializationBinder private class.
// ----------------------------------------
// ------------------------------------------
// End SqlClientObjRef private class.
// -------------------------------------------

// ----------------
// Instance members
Expand Down Expand Up @@ -306,10 +315,9 @@ public override Type BindToType(string assemblyName, string typeName)
private static readonly string _typeName = (typeof(SqlDependencyProcessDispatcher)).FullName;

// -----------
// BID members
// EventSource members
// -----------


private readonly int _objectID = System.Threading.Interlocked.Increment(ref _objectTypeCount);
private static int _objectTypeCount; // EventSource Counter
internal int ObjectID
Expand All @@ -336,7 +344,7 @@ public SqlDependency(SqlCommand command) : this(command, null, SQL.SqlDependency
}

/// <include file='..\..\..\..\..\..\..\doc\snippets\Microsoft.Data.SqlClient\SqlDependency.xml' path='docs/members[@name="SqlDependency"]/ctorCommandOptionsTimeout/*' />
[System.Security.Permissions.HostProtectionAttribute(ExternalThreading = true)]
[HostProtection(ExternalThreading = true)]
public SqlDependency(SqlCommand command, string options, int timeout)
{
long scopeID = SqlClientEventSource.Log.TryNotificationScopeEnterEvent("<sc.SqlDependency|DEP> {0}, options: '{1}', timeout: '{2}'", ObjectID, options, timeout);
Expand Down Expand Up @@ -597,11 +605,13 @@ private static void ObtainProcessDispatcher()
_processDispatcher = dependency.SingletonProcessDispatcher; // Set to static instance.

// Serialize and set in native.
ObjRef objRef = GetObjRef(_processDispatcher);
BinaryFormatter formatter = new BinaryFormatter();
MemoryStream stream = new MemoryStream();
GetSerializedObject(objRef, formatter, stream);
SNINativeMethodWrapper.SetData(stream.GetBuffer()); // Native will be forced to synchronize and not overwrite.
using (MemoryStream stream = new MemoryStream())
{
SqlClientObjRef objRef = new SqlClientObjRef(_processDispatcher);
DataContractSerializer serializer = new DataContractSerializer(objRef.GetType());
GetSerializedObject(objRef, serializer, stream);
SNINativeMethodWrapper.SetData(stream.ToArray()); // Native will be forced to synchronize and not overwrite.
}
}
else
{
Expand All @@ -628,37 +638,39 @@ private static void ObtainProcessDispatcher()
#if DEBUG // Possibly expensive, limit to debug.
SqlClientEventSource.Log.TryNotificationTraceEvent("<sc.SqlDependency.ObtainProcessDispatcher|DEP> AppDomain.CurrentDomain.FriendlyName: {0}", AppDomain.CurrentDomain.FriendlyName);
#endif
BinaryFormatter formatter = new BinaryFormatter();
MemoryStream stream = new MemoryStream(nativeStorage);
_processDispatcher = GetDeserializedObject(formatter, stream); // Deserialize and set for appdomain.
SqlClientEventSource.Log.TryNotificationTraceEvent("<sc.SqlDependency.ObtainProcessDispatcher|DEP> processDispatcher obtained, ID: {0}", _processDispatcher.ObjectID);
using (MemoryStream stream = new MemoryStream(nativeStorage))
{
DataContractSerializer serializer = new DataContractSerializer(typeof(SqlClientObjRef));
if (SqlClientObjRef.CanCastToSqlDependencyProcessDispatcher())
{
// Deserialize and set for appdomain.
_processDispatcher = GetDeserializedObject(serializer, stream);
}
else
{
throw new ArgumentException(Strings.SqlDependency_UnexpectedValueOnDeserialize);
}
SqlClientEventSource.Log.TryNotificationTraceEvent("<sc.SqlDependency.ObtainProcessDispatcher|DEP> processDispatcher obtained, ID: {0}", _processDispatcher.ObjectID);
}
}
}

// ---------------------------------------------------------
// Static security asserted methods - limit scope of assert.
// ---------------------------------------------------------

[SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.RemotingConfiguration)]
private static ObjRef GetObjRef(SqlDependencyProcessDispatcher _processDispatcher)
{
return RemotingServices.Marshal(_processDispatcher);
}

[SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.SerializationFormatter)]
private static void GetSerializedObject(ObjRef objRef, BinaryFormatter formatter, MemoryStream stream)
private static void GetSerializedObject(SqlClientObjRef objRef, DataContractSerializer serializer, MemoryStream stream)
{
formatter.Serialize(stream, objRef);
serializer.WriteObject(stream, objRef);
}

[SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.SerializationFormatter)]
private static SqlDependencyProcessDispatcher GetDeserializedObject(BinaryFormatter formatter, MemoryStream stream)
private static SqlDependencyProcessDispatcher GetDeserializedObject(DataContractSerializer serializer, MemoryStream stream)
{
// Use a custom SerializationBinder to restrict deserialized types to SqlDependencyProcessDispatcher.
formatter.Binder = new SqlDependencyProcessDispatcherSerializationBinder();
object result = formatter.Deserialize(stream);
Debug.Assert(result.GetType() == typeof(SqlDependencyProcessDispatcher), "Unexpected type stored in native!");
return (SqlDependencyProcessDispatcher)result;
object refResult = serializer.ReadObject(stream);
var result = RemotingServices.Unmarshal((refResult as SqlClientObjRef).GetObjRef());
return result as SqlDependencyProcessDispatcher;
}

// -------------------------
Expand Down Expand Up @@ -1325,7 +1337,6 @@ private void AddCommandInternal(SqlCommand cmd)
{
if (cmd != null)
{
// Don't bother with BID if command null.
long scopeID = SqlClientEventSource.Log.TryNotificationScopeEnterEvent("<sc.SqlDependency.AddCommandInternal|DEP> {0}, SqlCommand: {1}", ObjectID, cmd.ObjectID);
try
{
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Expand Up @@ -4602,4 +4602,7 @@
<data name="Azure_RetryFailure" xml:space="preserve">
<value>Failed after 5 retries.</value>
</data>
</root>
<data name="SqlDependency_UnexpectedValueOnDeserialize" xml:space="preserve">
<value>Unexpected type detected on deserialize.</value>
</data>
</root>
Expand Up @@ -33,7 +33,6 @@ public void SerializationTest()
Assert.Equal(e.StackTrace, sqlEx.StackTrace);
}


[Fact]
[ActiveIssue("12161", TestPlatforms.AnyUnix)]
public static void SqlExcpetionSerializationTest()
Expand Down

0 comments on commit 25cde90

Please sign in to comment.