Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test | Fix Unit Tests for GetSqlServerSPN #2442

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode)
case DataSource.Protocol.TCP:
sniHandle = CreateTcpHandle(details, timeout, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo,
tlsFirst, hostNameInCertificate, serverCertificateFilename);
break;
break;
case DataSource.Protocol.NP:
sniHandle = CreateNpHandle(details, timeout, parallel, tlsFirst);
break;
Expand Down Expand Up @@ -652,6 +652,11 @@ private bool InferConnectionDetails()
}

Port = port;

arellegue marked this conversation as resolved.
Show resolved Hide resolved
if (InstanceName == null && backSlashIndex > -1 && tokensByCommaAndSlash.Length == 3)
{
InstanceName = tokensByCommaAndSlash[1].Trim();
}
}
// Instance Name Handling. Only if we found a '\' and we did not find a port in the Data Source
else if (backSlashIndex > -1)
Expand Down Expand Up @@ -694,7 +699,7 @@ private bool InferNamedPipesInformation()
if (!_dataSourceAfterTrimmingProtocol.Contains(PipeBeginning))
{
// Assuming that user did not change default NamedPipe name, if the datasource is in the format servername\instance,
// separate servername and instance and prepend instance with MSSQL$ and append default pipe path
// separate server name and instance and prepend instance with MSSQL$ and append default pipe path
// https://learn.microsoft.com/en-us/sql/tools/configuration-manager/named-pipes-properties?view=sql-server-ver16
if (_dataSourceAfterTrimmingProtocol.Contains(PathSeparator) && _connectionProtocol == Protocol.NP)
{
Expand All @@ -719,6 +724,10 @@ private bool InferNamedPipesInformation()
}

InferLocalServerName();

if (InstanceName == null)
InstanceName = GetInstanceNameFromDataSource();

return true;
}

Expand Down Expand Up @@ -800,5 +809,17 @@ private bool InferNamedPipesInformation()

private static bool IsLocalHost(string serverName)
=> ".".Equals(serverName) || "(local)".Equals(serverName) || "localhost".Equals(serverName);

private string GetInstanceNameFromDataSource()
{
string instanceName = string.Empty;
string[] tokensByBackSlash = _dataSourceAfterTrimmingProtocol.Split(BackSlashCharacter);
if (tokensByBackSlash.Length > 1)
{
instanceName = tokensByBackSlash[1];
arellegue marked this conversation as resolved.
Show resolved Hide resolved
arellegue marked this conversation as resolved.
Show resolved Hide resolved
}

return instanceName;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Data.Common;
using System.Net;
using System.Net.Sockets;
using System.Reflection;
Expand Down Expand Up @@ -87,9 +88,8 @@ public static void ConnectManagedWithInstanceNameTest(bool useMultiSubnetFailove
}

#if NETCOREAPP
[ActiveIssue("27824")] // When specifying instance name and port number, this method call always returns false
[ConditionalFact(nameof(IsSPNPortNumberTestForTCP))]
public static void PortNumberInSPNTestForTCP()
public static void SPNTestForTCPMustReturnPortNumber()
{
string connectionString = DataTestUtility.TCPConnectionString;
SqlConnectionStringBuilder builder = new(connectionString);
Expand All @@ -98,11 +98,23 @@ public static void PortNumberInSPNTestForTCP()
Assert.True(port > 0, "Named instance must have a valid port number.");
builder.DataSource = $"{builder.DataSource},{port}";

PortNumberInSPNTest(builder.ConnectionString, port);
PortNumberInSPNTest(connectionString: builder.ConnectionString, expectedPortNumber: port);
}

[ConditionalFact(nameof(IsSPNPortNumberTestForNP))]
public static void SPNTestForNPMustReturnNamedInstance()
{
string connectionString = DataTestUtility.NPConnectionString;
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
SqlConnectionStringBuilder builder = new(connectionString);

DataTestUtility.ParseDataSource(builder.DataSource, out _, out _, out string instanceName);
arellegue marked this conversation as resolved.
Show resolved Hide resolved

Assert.True(!string.IsNullOrEmpty(instanceName), "Instance name must be included in data source.");
PortNumberInSPNTest(connectionString: builder.ConnectionString, expectedInstanceName: instanceName.ToUpper());
}
#endif

private static void PortNumberInSPNTest(string connectionString, int expectedPortNumber)
private static void PortNumberInSPNTest(string connectionString, int expectedPortNumber = 0, string expectedInstanceName = null)
{
if (DataTestUtility.IsIntegratedSecuritySetup())
{
Expand All @@ -125,15 +137,22 @@ private static void PortNumberInSPNTest(string connectionString, int expectedPor
connection.Open();

string spnInfo = GetSPNInfo(builder.DataSource);
Assert.Matches(@"MSSQLSvc\/.*:[\d]", spnInfo);

string[] spnStrs = spnInfo.Split(':');
int portInSPN = 0;
if (spnStrs.Length > 1)
if (expectedPortNumber > 0)
{
Assert.Matches(@"MSSQLSvc\/.*:[\d]", spnInfo);
string[] spnStrs = spnInfo.Split(':');
int portInSPN = 0;
if (spnStrs.Length > 1)
{
int.TryParse(spnStrs[1], out portInSPN);
}
Assert.Equal(expectedPortNumber, portInSPN);
}
else
{
int.TryParse(spnStrs[1], out portInSPN);
string[] spnStrs = spnInfo.Split(':');
Assert.Equal(expectedInstanceName, spnStrs[1].ToUpper());
}
Assert.Equal(expectedPortNumber, portInSPN);
}
}

Expand Down Expand Up @@ -180,7 +199,7 @@ private static string GetSPNInfo(string dataSource)
string serverName = serverInfo.GetValue(dataSrcInfo, null).ToString();

PropertyInfo instanceNameInfo = dataSrcInfo.GetType().GetProperty("InstanceName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
string instanceName = instanceNameInfo.GetValue(dataSrcInfo, null).ToString();
string instanceName = instanceNameInfo.GetValue(dataSrcInfo, null).ToString().ToUpper();

object port = getPortByInstanceNameInfo.Invoke(ssrpObj, parameters: new object[] { serverName, instanceName, timeoutTimerObj, false, 0 });

Expand All @@ -205,6 +224,13 @@ private static bool IsSPNPortNumberTestForTCP()
&& DataTestUtility.IsNotAzureSynapse());
}

private static bool IsSPNPortNumberTestForNP()
{
return (IsInstanceNameValid(DataTestUtility.NPConnectionString)
&& DataTestUtility.IsUsingManagedSNI()
&& DataTestUtility.IsNotAzureServer()
&& DataTestUtility.IsNotAzureSynapse());
}
private static bool IsInstanceNameValid(string connectionString)
{
string instanceName = "";
Expand Down