diff --git a/src/lib.rs b/src/lib.rs index c52d14d..2c5dae1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -220,10 +220,23 @@ impl Complex { /// The branch satisfies `-π/2 ≤ arg(sqrt(z)) ≤ π/2`. #[inline] pub fn sqrt(&self) -> Self { - // formula: sqrt(r e^(it)) = sqrt(r) e^(it/2) - let two = T::one() + T::one(); - let (r, theta) = self.to_polar(); - Self::from_polar(&(r.sqrt()), &(theta / two)) + if self.im.is_zero() { + if self.re.is_sign_positive() { + // simple positive real √r, and copy `im` for its sign + Self::new(self.re.sqrt(), self.im) + } else if self.im.is_sign_positive() { + // √(r e^(iπ)) = √r e^(iπ/2) = i√r + Self::new(T::zero(), self.re.abs().sqrt()) + } else { + // √(r e^(-iπ)) = √r e^(-iπ/2) = -i√r + Self::new(T::zero(), -self.re.abs().sqrt()) + } + } else { + // formula: sqrt(r e^(it)) = sqrt(r) e^(it/2) + let two = T::one() + T::one(); + let (r, theta) = self.to_polar(); + Self::from_polar(&(r.sqrt()), &(theta / two)) + } } /// Raises `self` to a floating point power. @@ -1690,6 +1703,18 @@ mod test { } } + #[test] + fn test_sqrt_real() { + for n in (0..100).map(f64::from) { + // √(n² + 0i) = n + 0i + assert_eq!(Complex64::new(n * n, 0.0).sqrt(), Complex64::new(n, 0.0)); + // √(-n² + 0i) = 0 + ni + assert_eq!(Complex64::new(-n * n, 0.0).sqrt(), Complex64::new(0.0, n)); + // √(-n² - 0i) = 0 - ni + assert_eq!(Complex64::new(-n * n, -0.0).sqrt(), Complex64::new(0.0, -n)); + } + } + #[test] fn test_sin() { assert!(close(_0_0i.sin(), _0_0i));