diff --git a/src/lib.rs b/src/lib.rs index 934ea29..52cfe02 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -224,16 +224,32 @@ impl Complex { if self.re.is_sign_positive() { // simple positive real √r, and copy `im` for its sign Self::new(self.re.sqrt(), self.im) - } else if self.im.is_sign_positive() { - // √(r e^(iπ)) = √r e^(iπ/2) = i√r - Self::new(T::zero(), self.re.abs().sqrt()) } else { + // √(r e^(iπ)) = √r e^(iπ/2) = i√r // √(r e^(-iπ)) = √r e^(-iπ/2) = -i√r - Self::new(T::zero(), -self.re.abs().sqrt()) + let re = T::zero(); + let im = (-self.re).sqrt(); + if self.im.is_sign_positive() { + Self::new(re, im) + } else { + Self::new(re, -im) + } + } + } else if self.re.is_zero() { + // √(r e^(iπ/2)) = √r e^(iπ/4) = √(r/2) + i√(r/2) + // √(r e^(-iπ/2)) = √r e^(-iπ/4) = √(r/2) - i√(r/2) + let one = T::one(); + let two = one + one; + let x = (self.im.abs() / two).sqrt(); + if self.im.is_sign_positive() { + Self::new(x, x) + } else { + Self::new(x, -x) } } else { // formula: sqrt(r e^(it)) = sqrt(r) e^(it/2) - let two = T::one() + T::one(); + let one = T::one(); + let two = one + one; let (r, theta) = self.to_polar(); Self::from_polar(&(r.sqrt()), &(theta / two)) } @@ -1748,8 +1764,8 @@ mod test { assert!(close(c.conj().sqrt(), c.sqrt().conj())); // for this branch, -pi/2 <= arg(sqrt(z)) <= pi/2 assert!( - -f64::consts::PI / 2.0 <= c.sqrt().arg() - && c.sqrt().arg() <= f64::consts::PI / 2.0 + -f64::consts::FRAC_PI_2 <= c.sqrt().arg() + && c.sqrt().arg() <= f64::consts::FRAC_PI_2 ); // sqrt(z) * sqrt(z) = z assert!(close(c.sqrt() * c.sqrt(), c)); @@ -1760,11 +1776,29 @@ mod test { fn test_sqrt_real() { for n in (0..100).map(f64::from) { // √(n² + 0i) = n + 0i - assert_eq!(Complex64::new(n * n, 0.0).sqrt(), Complex64::new(n, 0.0)); + let n2 = n * n; + assert_eq!(Complex64::new(n2, 0.0).sqrt(), Complex64::new(n, 0.0)); // √(-n² + 0i) = 0 + ni - assert_eq!(Complex64::new(-n * n, 0.0).sqrt(), Complex64::new(0.0, n)); + assert_eq!(Complex64::new(-n2, 0.0).sqrt(), Complex64::new(0.0, n)); // √(-n² - 0i) = 0 - ni - assert_eq!(Complex64::new(-n * n, -0.0).sqrt(), Complex64::new(0.0, -n)); + assert_eq!(Complex64::new(-n2, -0.0).sqrt(), Complex64::new(0.0, -n)); + } + } + + #[test] + fn test_sqrt_imag() { + for n in (0..100).map(f64::from) { + // √(0 + n²i) = n e^(iπ/4) + let n2 = n * n; + assert!(close( + Complex64::new(0.0, n2).sqrt(), + Complex64::from_polar(&n, &(f64::consts::FRAC_PI_4)) + )); + // √(0 - n²i) = n e^(-iπ/4) + assert!(close( + Complex64::new(0.0, -n2).sqrt(), + Complex64::from_polar(&n, &(-f64::consts::FRAC_PI_4)) + )); } }