Optimise multiplication #295

Merged
merged 4 commits into from Apr 24, 2024
10 changes: 10 additions & 0 deletions benches/bigint.rs
@@ -87,6 +87,16 @@ fn multiply_3(b: &mut Bencher) {
multiply_bench(b, 1 << 16, 1 << 17);
}

#[bench]
fn multiply_4(b: &mut Bencher) {
multiply_bench(b, 1 << 12, 1 << 13);
}

#[bench]
fn multiply_5(b: &mut Bencher) {
multiply_bench(b, 1 << 12, 1 << 14);
}

#[bench]
fn divide_0(b: &mut Bencher) {
divide_bench(b, 1 << 8, 1 << 6);
83 changes: 82 additions & 1 deletion src/biguint/multiplication.rs
@@ -88,9 +88,10 @@ fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) {
let acc = acc;
let (x, y) = if b.len() < c.len() { (b, c) } else { (c, b) };

// We use three algorithms for different input sizes.
// We use four algorithms for different input sizes.
//
// - For small inputs, long multiplication is fastest.
// - If y is at least twice as long as x, split using Half-Karatsuba.
// - Next we use Karatsuba multiplication (Toom-2), which we have optimized
// to avoid unnecessary allocations for intermediate values.
// - For the largest inputs we use Toom-3, which better optimizes the
@@ -104,6 +105,86 @@ fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) {
for (i, xi) in x.iter().enumerate() {
mac_digit(&mut acc[i..], y, *xi);
}
} else if x.len() * 2 <= y.len() {
// Karatsuba Multiplication for factors with significant length disparity.
//
// The Half-Karatsuba Multiplication Algorithm is a specialized case of
// the normal Karatsuba multiplication algorithm, designed for the scenario
// where y has at least twice as many base digits as x.
//
// In this case y (the longer input) is split into high2 and low2
// at m2 (half the length of y), while x (the shorter input)
// is used directly without splitting.
//
// The algorithm then proceeds as follows:
//
// 1. Compute the product z0 = x * low2.
// 2. Compute the product temp = x * high2.
// 3. Shift temp up by m2 digits (i.e. multiply it by NBASE ^ m2).
// 4. Add temp and z0 to obtain the final result.
//
// Proof:
//
// The algorithm can be derived from the original Karatsuba algorithm by
// simplifying the formula when the shorter factor x is not split into
// high and low parts, as shown below.
//
// Original Karatsuba formula:
//
// result = (z2 * NBASE ^ (m2 × 2)) + ((z1 - z2 - z0) * NBASE ^ m2) + z0
//
// Substitutions:
//
// low1 = x
// high1 = 0
//
// Applying substitutions:
//
// z0 = (low1 * low2)
// = (x * low2)
//
// z1 = ((low1 + high1) * (low2 + high2))
// = ((x + 0) * (low2 + high2))
// = (x * low2) + (x * high2)
//
// z2 = (high1 * high2)
// = (0 * high2)
// = 0
//
// Simplified using the above substitutions:
//
// result = (z2 * NBASE ^ (m2 × 2)) + ((z1 - z2 - z0) * NBASE ^ m2) + z0
// = (0 * NBASE ^ (m2 × 2)) + ((z1 - 0 - z0) * NBASE ^ m2) + z0
// = ((z1 - z0) * NBASE ^ m2) + z0
// = ((x * low2 + x * high2 - x * low2) * NBASE ^ m2) + z0
// = (x * high2) * NBASE ^ m2 + z0
let m2 = y.len() / 2;
let (low2, high2) = y.split_at(m2);

// We reuse the same BigUint for all the intermediate multiplies and have to size p
// appropriately here: x.len() and high2.len() could be different:
let len = x.len() + high2.len() + 1;
let mut p = BigUint { data: vec![0; len] };

// z0 = x * low2
mac3(&mut p.data, x, low2);
p.normalize();

// Add z0 directly to the accumulator
add2(acc, &p.data);

// Zero out p before the next multiply:
p.data.truncate(0);
p.data.resize(len, 0);

// temp = x * high2
mac3(&mut p.data, x, high2);
p.normalize();

// Add temp shifted by m2 to the accumulator
// This has the effect of multiplying temp by NBASE ^ m2.
// Add directly starting at index m2 in the accumulator.
add2(&mut acc[m2..], &p.data);
Member

Since each product is immediately added to acc, I think we don't even need the p buffer at all:

        mac3(acc, x, low2);
        mac3(&mut acc[m2..], x, high2);

It doesn't change the benchmark times for me, but still, it's an easy allocation to avoid.
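
For illustration only, here is a self-contained sketch of that buffer-free form. It is not the num-bigint code: it assumes little-endian `u32` digits, uses plain long multiplication in place of the recursive `mac3` call, and the names `mac_long` and `mac_half_split` are hypothetical. The accumulator is assumed to be long enough to hold the full product.

```rust
// Illustrative sketch only: digits are little-endian u32 values, and
// `mac_long` / `mac_half_split` are hypothetical names, not num-bigint APIs.

/// acc += b * c using plain long multiplication.
/// Assumes `acc` is long enough that the sum never carries past its last digit.
fn mac_long(acc: &mut [u32], b: &[u32], c: &[u32]) {
    for (i, &bi) in b.iter().enumerate() {
        let mut carry: u64 = 0;
        for (j, &cj) in c.iter().enumerate() {
            // Largest possible value is exactly u64::MAX, so this cannot overflow.
            let t = acc[i + j] as u64 + (bi as u64) * (cj as u64) + carry;
            acc[i + j] = t as u32;
            carry = t >> 32;
        }
        // Propagate the remaining carry into the higher digits.
        let mut k = i + c.len();
        while carry != 0 {
            let t = acc[k] as u64 + carry;
            acc[k] = t as u32;
            carry = t >> 32;
            k += 1;
        }
    }
}

/// acc += x * y, splitting only the longer operand y at m2 = y.len() / 2.
/// No temporary buffer is needed: each partial product is accumulated in place.
fn mac_half_split(acc: &mut [u32], x: &[u32], y: &[u32]) {
    let m2 = y.len() / 2;
    let (low2, high2) = y.split_at(m2);
    // z0 = x * low2, accumulated at digit offset 0.
    mac_long(acc, x, low2);
    // x * high2, accumulated at digit offset m2 (i.e. scaled by NBASE ^ m2).
    mac_long(&mut acc[m2..], x, high2);
}
```

With `x = [3]`, `y = [4, 5]` and a zeroed three-digit accumulator, `mac_half_split` leaves `[12, 15, 0]` in the accumulator, the same digits long multiplication produces.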

Contributor Author

Right! Very nice. Thanks, I've pushed the simplification.

Contributor Author

Sorry about the format error due to trailing newline, fixed and pushed.

} else if x.len() <= 256 {
// Karatsuba multiplication:
//