Skip to content

Commit

Permalink
More robust detection of non-ASCII input (#170)
Browse files Browse the repository at this point in the history
* Updated fuzzing

* More tests for non-ASCII input

* Fix
  • Loading branch information
gfoidl committed Sep 30, 2022
1 parent c6a4f9a commit b7b6ddf
Show file tree
Hide file tree
Showing 34 changed files with 144 additions and 122 deletions.
76 changes: 17 additions & 59 deletions fuzz/gfoidl.Base64.FuzzTests/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using System;
using System.Buffers;
using System.IO;
using System.IO.Pipelines;
using gfoidl.Base64.Internal;
using SharpFuzz;

Expand All @@ -18,8 +17,8 @@ static void Main(string[] args)
// Note: Build config "Fuzz" must be set for debugging in VS

// To verify a crash switch to Debug in the configuration manager and uncomment the following lines
//var stream = File.OpenRead(@".\findings\crashes\id%3A000001,sig%3A02,src%3A000000,op%3Ahavoc,rep%3A64");
//Base64_Url_Decode(stream);
//string base64Text = File.ReadAllText("./findings/crashes/id-000000,sig-02,src-000000,op-flip1,pos-0");
//Base64_Default_Decode(base64Text);
//return;

if (args.Length < 1)
Expand All @@ -40,65 +39,24 @@ static void Main(string[] args)
}
}
//---------------------------------------------------------------------
private static void Base64_Default_Decode(Stream stream) => Base64_Decode(stream, Base64.Default);
private static void Base64_Url_Decode (Stream stream) => Base64_Decode(stream, Base64.Url);
private static void Base64_Default_Decode(string base64Text) => Base64_Decode(base64Text, Base64.Default);
private static void Base64_Url_Decode (string base64Text) => Base64_Decode(base64Text, Base64.Url);
//---------------------------------------------------------------------
private static void Base64_Decode(Stream stream, Base64 encoder)
{
PipeReader pipeReader = PipeReader.Create(stream, new StreamPipeReaderOptions(leaveOpen: true));

while (true)
{
ReadResult readResult = pipeReader.ReadAsync().AsTask().GetAwaiter().GetResult();

if (readResult.IsCompleted || readResult.IsCanceled)
break;

ReadOnlySequence<byte> buffer = readResult.Buffer;

if (buffer.Length > int.MaxValue)
return;

if (buffer.IsSingleSegment)
{
Base64_Decode(buffer.FirstSpan, encoder);
}
else
{
byte[] arrayToReturnToPool = ArrayPool<byte>.Shared.Rent((int)buffer.Length);
try
{
buffer.CopyTo(arrayToReturnToPool);
Base64_Decode(arrayToReturnToPool.AsSpan(0, (int)buffer.Length), encoder);
}
finally
{
ArrayPool<byte>.Shared.Return(arrayToReturnToPool);
}
}

pipeReader.AdvanceTo(buffer.End);
}

pipeReader.Complete();
}
//---------------------------------------------------------------------
private static void Base64_Decode(ReadOnlySpan<byte> encoded, Base64 encoder)
private static void Base64_Decode(ReadOnlySpan<char> base64Text, Base64 encoder)
{
byte[]? dataArrayFromPool = null;
try
{
dataArrayFromPool = ArrayPool<byte>.Shared.Rent(encoder.GetMaxDecodedLength(encoded.Length));
dataArrayFromPool = ArrayPool<byte>.Shared.Rent(encoder.GetMaxDecodedLength(base64Text.Length));

OperationStatus status = encoder.Decode(encoded, dataArrayFromPool, out int consumed, out int written);
OperationStatus status = encoder.Decode(base64Text, dataArrayFromPool, out int consumed, out int written);

if (ContainsInvalidData(encoded, encoder) && status != OperationStatus.InvalidData)
if (ContainsInvalidData(base64Text, encoder) && status != OperationStatus.InvalidData)
throw new Exception("contains invalid data -- not detected");

if (status == OperationStatus.Done)
if (status == OperationStatus.Done && consumed != base64Text.Length)
{
if (consumed != encoded.Length)
throw new Exception("consumed != encoded.Length");
throw new Exception("consumed != encoded.Length");
}
}
catch (ArgumentOutOfRangeException)
Expand All @@ -118,7 +76,7 @@ private static void Base64_Decode(ReadOnlySpan<byte> encoded, Base64 encoder)
}
}
//---------------------------------------------------------------------
private static bool ContainsInvalidData(ReadOnlySpan<byte> encoded, Base64 encoder)
private static bool ContainsInvalidData(ReadOnlySpan<char> base64Text, Base64 encoder)
{
ReadOnlySpan<sbyte> decodingMap = default;

Expand All @@ -129,21 +87,21 @@ private static bool ContainsInvalidData(ReadOnlySpan<byte> encoded, Base64 encod
// Check for padding at the end
int paddingCount = 0;

if (encoded.Length > 1 && encoded[^1] == Base64Encoder.EncodingPad) paddingCount++;
if (encoded.Length > 2 && encoded[^2] == Base64Encoder.EncodingPad) paddingCount++;
if (base64Text.Length > 1 && base64Text[^1] == Base64Encoder.EncodingPad) paddingCount++;
if (base64Text.Length > 2 && base64Text[^2] == Base64Encoder.EncodingPad) paddingCount++;

encoded = encoded[0..^paddingCount];
base64Text = base64Text[0..^paddingCount];
}
else if (encoder is Base64UrlEncoder)
{
decodingMap = Base64UrlEncoder.DecodingMap;
}

for (int i = 0; i < encoded.Length; ++i)
for (int i = 0; i < base64Text.Length; ++i)
{
byte e = encoded[i];
char ch = base64Text[i];

if (decodingMap[e] == -1) return true;
if (ch > 0x7F || decodingMap[ch] == -1) return true;
}

return false;
Expand Down
3 changes: 1 addition & 2 deletions fuzz/gfoidl.Base64.FuzzTests/gfoidl.Base64.FuzzTests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="SharpFuzz" Version="1.6.2" />
<PackageReference Include="SharpFuzz" Version="2.0.0" />
<PackageReference Include="Sigil" Version="5.0.0" />
<PackageReference Include="System.IO.Pipelines" Version="5.0.1" />
</ItemGroup>

<ItemGroup>
Expand Down
3 changes: 2 additions & 1 deletion fuzz/init.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ set -e
echo "See https://github.com/Metalnem/sharpfuzz for setup"
echo ""

tfm="net5.0"
lib=gfoidl.Base64.dll

dotnet build -c Release gfoidl.Base64.FuzzTests
mkdir -p ./instrumented
cp ../source/gfoidl.Base64/bin/Release/net5.0/$lib ./instrumented/$lib
cp ../source/gfoidl.Base64/bin/Release/$tfm/$lib ./instrumented/$lib

sharpfuzz ./instrumented/$lib

Expand Down
1 change: 1 addition & 0 deletions fuzz/testcases/2.dat
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
z/TpH7sqEkerqMwelH1uSw=
1 change: 1 addition & 0 deletions fuzz/testcases/3.dat
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
z/TpH7sqEkerqMwelH1uSw🌄
5 changes: 3 additions & 2 deletions gfoidl.Base64.sln
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Version 16
VisualStudioVersion = 16.0.28803.452
# Visual Studio Version 17
VisualStudioVersion = 17.3.32922.545
MinimumVisualStudioVersion = 10.0.40219.1
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{403238BD-BE8A-4E78-924A-C841CD8C7955}"
ProjectSection(SolutionItems) = preProject
Expand Down Expand Up @@ -148,6 +148,7 @@ Global
{41097931-26CD-4203-B69C-A14CA5FBB69D}.Release|Any CPU.ActiveCfg = Release|Any CPU
{41097931-26CD-4203-B69C-A14CA5FBB69D}.Release|Any CPU.Build.0 = Release|Any CPU
{DC515055-B8E7-4C25-B4CF-703292C4F975}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{DC515055-B8E7-4C25-B4CF-703292C4F975}.Debug|Any CPU.Build.0 = Debug|Any CPU
{DC515055-B8E7-4C25-B4CF-703292C4F975}.Fuzz|Any CPU.ActiveCfg = Release|Any CPU
{DC515055-B8E7-4C25-B4CF-703292C4F975}.Fuzz|Any CPU.Build.0 = Release|Any CPU
{DC515055-B8E7-4C25-B4CF-703292C4F975}.Release|Any CPU.ActiveCfg = Release|Any CPU
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ public partial class Base64Encoder
t1 = Unsafe.Add(ref tmp, (IntPtr)(uint)(srcLength - 3));
t2 = Unsafe.Add(ref tmp, (IntPtr)(uint)(srcLength - 2));
t3 = Unsafe.Add(ref tmp, (IntPtr)(uint)(srcLength - 1));

// Need to sanitize for non-ASCII input
if (((t0 | t1) | (t2 | t3)) > 256)
{
t0 = t1 = t2 = t3 = 0;
}
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ public partial class Base64Encoder
t1 = Unsafe.Add(ref tmp, (IntPtr)(uint)(srcLength - 3));
t2 = Unsafe.Add(ref tmp, (IntPtr)(uint)(srcLength - 2));
t3 = Unsafe.Add(ref tmp, (IntPtr)(uint)(srcLength - 1));

// Need to sanitize for non-ASCII input
if (((t0 | t1) | (t2 | t3)) > 256)
{
t0 = t1 = t2 = t3 = 0;
}
}
else
{
Expand Down
6 changes: 6 additions & 0 deletions source/gfoidl.Base64/Internal/Base64EncoderImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ protected static int DecodeFour<T>(ref T encoded, ref sbyte decodingMap)
t1 = Unsafe.Add(ref tmp, 1);
t2 = Unsafe.Add(ref tmp, 2);
t3 = Unsafe.Add(ref tmp, 3);

// Need to sanitize for non-ASCII input
if (((t0 | t1) | (t2 | t3)) > 256)
{
t0 = t1 = t2 = t3 = 0;
}
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ private static int DecodeThree<T>(ref T encoded, ref sbyte decodingMap)
t0 = Unsafe.Add(ref tmp, 0);
t1 = Unsafe.Add(ref tmp, 1);
t2 = Unsafe.Add(ref tmp, 2);

// Need to sanitize for non-ASCII input
if (((t0 | t1) | t2) > 256)
{
t0 = t1 = t2 = 0;
}
}
else
{
Expand Down Expand Up @@ -180,6 +186,12 @@ private static int DecodeTwo<T>(ref T encoded, ref sbyte decodingMap)
ref char tmp = ref Unsafe.As<T, char>(ref encoded);
t0 = Unsafe.Add(ref tmp, 0);
t1 = Unsafe.Add(ref tmp, 1);

// Need to sanitize for non-ASCII input
if ((t0 | t1) > 256)
{
t0 = t1 = 0;
}
}
else
{
Expand Down
4 changes: 2 additions & 2 deletions source/gfoidl.Base64/gfoidl.Base64.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ For .NET Core 3.0 onwards encoding / decoding is done with SIMD-support.

<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0' OR '$(TargetFramework)' == 'net45'">
<PackageReference Include="System.Buffers" Version="4.5.1" />
<PackageReference Include="System.Memory" Version="4.5.4" />
<PackageReference Include="System.Memory" Version="4.5.5" />
<PackageReference Include="System.Runtime.CompilerServices.Unsafe" Version="5.0.0" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.1'">
<PackageReference Include="System.Runtime.CompilerServices.Unsafe" Version="4.6.0" />
<PackageReference Include="System.Runtime.CompilerServices.Unsafe" Version="6.0.0" />
</ItemGroup>

<ItemGroup>
Expand Down
6 changes: 3 additions & 3 deletions tests/Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.9.4" />
<PackageReference Include="NUnit" Version="3.13.1" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" />
<PackageReference Include="NUnit" Version="3.13.3" />
<PackageReference Include="NUnit3TestAdapter" Version="3.17.0">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="coverlet.msbuild" Version="3.0.3">
<PackageReference Include="coverlet.msbuild" Version="3.1.2">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public void Basic_decoding_with_known_input___Done(string input, int expectedWri

if (typeof(T) == typeof(byte))
{
ReadOnlySpan<byte> tmp = Encoding.ASCII.GetBytes(input);
ReadOnlySpan<byte> tmp = Encoding.UTF8.GetBytes(input);
encoded = MemoryMarshal.Cast<byte, T>(tmp);
}
else if (typeof(T) == typeof(char))
Expand Down Expand Up @@ -64,7 +64,7 @@ public void Basic_decoding_with_known_input_isFinalBlock_false___Done(string inp

if (typeof(T) == typeof(byte))
{
ReadOnlySpan<byte> tmp = Encoding.ASCII.GetBytes(input);
ReadOnlySpan<byte> tmp = Encoding.UTF8.GetBytes(input);
encoded = MemoryMarshal.Cast<byte, T>(tmp);
}
else if (typeof(T) == typeof(char))
Expand Down Expand Up @@ -96,16 +96,20 @@ public void Basic_decoding_with_known_input_isFinalBlock_false___Done(string inp
[TestCase("A", 0, 0)]
[TestCase("AQ", 0, 0)]
[TestCase("AQI", 0, 0)]
[TestCase("🌄34", 0, 0)] // length is 4
[TestCase("AQIDBA", 4, 3)]
[TestCase("AQIDBAU", 4, 3)]
[TestCase("ìz/TpH7sqEkerqMweH1uSw==", 0, 0)] // note the ì character
[TestCase("z/TpH7sqEkerqMweìH1uSw==", 16, 12)]
[TestCase("🌄a/TpH7sqEkerqMwelH1uSw", 0, 0)]
public void Basic_decoding_with_invalid_input___InvalidData(string input, int expectedConsumed, int expectedWritten)
{
var sut = new Base64Encoder();
ReadOnlySpan<T> encoded;

if (typeof(T) == typeof(byte))
{
ReadOnlySpan<byte> tmp = Encoding.ASCII.GetBytes(input);
ReadOnlySpan<byte> tmp = Encoding.UTF8.GetBytes(input);
encoded = MemoryMarshal.Cast<byte, T>(tmp);
}
else if (typeof(T) == typeof(char))
Expand Down Expand Up @@ -147,7 +151,7 @@ public void Basic_decoding_with_invalid_input_isFinalBlock_false___NeedMoreData(

if (typeof(T) == typeof(byte))
{
ReadOnlySpan<byte> tmp = Encoding.ASCII.GetBytes(input);
ReadOnlySpan<byte> tmp = Encoding.UTF8.GetBytes(input);
encoded = MemoryMarshal.Cast<byte, T>(tmp);
}
else if (typeof(T) == typeof(char))
Expand Down Expand Up @@ -225,7 +229,7 @@ public void Basic_decoding_with_padding_at_end_isFinalBlock_false___InvalidData(

if (typeof(T) == typeof(byte))
{
ReadOnlySpan<byte> tmp = Encoding.ASCII.GetBytes(input);
ReadOnlySpan<byte> tmp = Encoding.UTF8.GetBytes(input);
encoded = MemoryMarshal.Cast<byte, T>(tmp);
}
else if (typeof(T) == typeof(char))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public void Malformed_input___status_InvalidData(string input, bool isFinalBlock

if (typeof(T) == typeof(byte))
{
ReadOnlySpan<byte> tmp = Encoding.ASCII.GetBytes(input);
ReadOnlySpan<byte> tmp = Encoding.UTF8.GetBytes(input);
encoded = MemoryMarshal.Cast<byte, T>(tmp);
}
else if (typeof(T) == typeof(char))
Expand Down Expand Up @@ -287,7 +287,7 @@ public void Padding_can_only_be_last_two___InvalidData(bool isFinalBlock)

if (typeof(T) == typeof(byte))
{
ReadOnlySpan<byte> tmp = Encoding.ASCII.GetBytes(encodedString);
ReadOnlySpan<byte> tmp = Encoding.UTF8.GetBytes(encodedString);
encoded = MemoryMarshal.Cast<byte, T>(tmp);
}
else if (typeof(T) == typeof(char))
Expand Down Expand Up @@ -324,7 +324,7 @@ public void Padding_can_only_be_last_two___InvalidData(bool isFinalBlock)

if (typeof(T) == typeof(byte))
{
ReadOnlySpan<byte> tmp = Encoding.ASCII.GetBytes(encodedString);
ReadOnlySpan<byte> tmp = Encoding.UTF8.GetBytes(encodedString);
encoded = MemoryMarshal.Cast<byte, T>(tmp);
}
else if (typeof(T) == typeof(char))
Expand Down

0 comments on commit b7b6ddf

Please sign in to comment.