algorithm haskell math number-theory

Most efficient algorithm to find integer points within an ellipsoid


I'm trying to find all the integer lattice points within various 3D ellipsoids. I would like my program to take an integer N, and count all the lattice points within the ellipsoids of the form ax^2 + by^2 + cz^2 = n, where a, b, c are fixed integers and n is between 1 and N. This program should then return N tuples of the form (n, numlatticePointsWithinEllipse n).

I'm currently doing it by counting the points on the ellipsoids ax^2 + by^2 + cz^2 = m, for m between 0 and n inclusive, and then summing over m. I also initially consider only non-negative x, y and z, and then add in the negatives by permuting their signs later. Ideally, I'd like to reach N = 1,000,000+ within a matter of hours.

Taking a specific example of x^2 + y^2 + 3z^2 = N, here's the Haskell code I'm currently using:

import System.Environment

isqrt :: Int -> Int
isqrt 0 = 0
isqrt 1 = 1
isqrt n = head $ dropWhile (\x -> x*x > n) $ iterate (\x -> (x + n `div` x) `div` 2) (n `div` 2)

latticePointsWithoutNegatives :: Int -> [[Int]]
latticePointsWithoutNegatives 0 = [[0,0,0]]
latticePointsWithoutNegatives n = [[x,y,z] | x<-[0.. isqrt n], y<- [0.. isqrt (n - x^2)], z<-[max 0 (isqrt ((n-x^2 -y^2) `div` 3))], x^2 + y^2 + 3*z^2 == n]

latticePoints :: Int -> [[Int]]
latticePoints n = [ zipWith (*) [x1,x2,x3] y |  [x1,x2,x3] <- (latticePointsWithoutNegatives n), y <- [[a,b,c] | a <- (if x1 == 0 then [0] else [-1,1]), b<-(if x2 == 0 then [0] else [-1,1]), c<-(if x3 == 0 then [0] else [-1,1])]]

latticePointsUpTo :: Int -> Int
latticePointsUpTo n = sum [length (latticePoints x) | x<-[0..n]]

listResults :: Int -> [(Int, Int)]
listResults n = [(x, latticePointsUpTo x) | x<- [1..n]]

main = do
    args <- getArgs
    let cleanArgs = read (head args)
    print (listResults cleanArgs)

I've compiled this with

ghc -O2 latticePointsTest

but using the PowerShell "Measure-Command" command, I get the following results:

Measure-Command{./latticePointsTest 10}
TotalMilliseconds : 12.0901

Measure-Command{./latticePointsTest 100}
TotalMilliseconds : 12.0901

Measure-Command{./latticePointsTest 1000}
TotalMilliseconds : 31120.4503

and going any more orders of magnitude up takes us onto the scale of days, rather than hours or minutes.

Is there anything fundamentally wrong with the algorithm I'm using? Is there any core reason why my code isn't scaling well? Any guidance will be greatly appreciated. I may also want to process the data between latticePoints and latticePointsUpTo, so I can't just rely entirely on clever number-theoretic counting techniques - I need the underlying tuples preserved.


Solution

  • Some things I would try:

    isqrt is not efficient for the range of values you are working with. Simply use the floating point sqrt function:

    isqrt n = floor $ sqrt (fromIntegral n :: Double)
    

    Alternatively, instead of computing integer square roots, use logic like this in your list comprehensions:

    x <- takeWhile (\x -> x*x <= n) [0..],
    y <- takeWhile (\y -> y*y <= n - x*x) [0..]
    

    Also, I would use expressions like x*x instead of x^2.

    Next, why not compute the number of solutions directly with something like this:

    sols a b c n =
      length [ () | x <- takeWhile (\x -> a*x*x <= n) [0..]
                  , y <- takeWhile (\y -> a*x*x+b*y*y <= n) [0..]
                  , z <- takeWhile (\z -> a*x*x+b*y*y+c*z*z <= n) [0..]
             ]
    

    This does not exactly compute the same answer that you want because it doesn't account for positive and negative solutions, but you could easily modify it to compute your answer. The idea is to use one list comprehension instead of iterating over various values of n and summing.
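One way the modification mentioned above might look is sketched below. It weights each non-negative solution by its number of sign variants, and it also buckets the points by their value of a*x^2 + b*y^2 + c*z^2 so that the totals for every n come from a single enumeration. The names signs, signedCount and cumulativeCounts are made up for this illustration, and the sketch assumes a, b and c are positive.

```haskell
import Data.Array (Array, accumArray, elems)

-- Number of sign variants of a non-negative coordinate:
-- 0 has one variant, k > 0 has two (+k and -k).
signs :: Int -> Int
signs 0 = 1
signs _ = 2

-- Integer points (negatives included) with a*x^2 + b*y^2 + c*z^2 <= n.
-- Assumes a, b, c > 0 (an assumption of this sketch).
signedCount :: Int -> Int -> Int -> Int -> Int
signedCount a b c n =
  sum [ signs x * signs y * signs z
      | x <- takeWhile (\x -> a*x*x <= n) [0..]
      , y <- takeWhile (\y -> a*x*x + b*y*y <= n) [0..]
      , z <- takeWhile (\z -> a*x*x + b*y*y + c*z*z <= n) [0..]
      ]

-- Pairs (n, points with a*x^2 + b*y^2 + c*z^2 <= n) for n in [1..nMax],
-- built from one pass over the non-negative octant.
cumulativeCounts :: Int -> Int -> Int -> Int -> [(Int, Int)]
cumulativeCounts a b c nMax = drop 1 (zip [0..] totals)
  where
    -- onShell holds, for each m, the number of points with value exactly m.
    onShell :: Array Int Int
    onShell = accumArray (+) 0 (0, nMax)
      [ (m, signs x * signs y * signs z)
      | x <- takeWhile (\x -> a*x*x <= nMax) [0..]
      , y <- takeWhile (\y -> a*x*x + b*y*y <= nMax) [0..]
      , z <- takeWhile (\z -> a*x*x + b*y*y + c*z*z <= nMax) [0..]
      , let m = a*x*x + b*y*y + c*z*z
      ]
    -- Running sum turns per-shell counts into "within the ellipsoid" counts.
    totals = scanl1 (+) (elems onShell)
```

For the x^2 + y^2 + 3z^2 example, cumulativeCounts 1 1 3 4 should evaluate to [(1,5),(2,9),(3,11),(4,23)]. If the underlying tuples are needed, the comprehension could yield (x, y, z) alongside m instead of only the weight.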

    Finally, I think using floor and sqrt to compute the integer square root is completely safe in this case. This code searches for a counterexample to floor (sqrt (x*x)) == x among all x <= 3037000499:

    testAll :: Int -> IO ()
    testAll n =
      -- Prints the first counterexample (x, a), or [] if none exists.
      print $ take 1 [ (x,a) | x <- [n,n-1 .. 1], let a = floor $ sqrt (fromIntegral (x*x) :: Double), a /= x ]
    
    main = testAll 3037000499
    

    Note I am running this on a 64-bit GHC. Otherwise just use Int64 instead of Int, since Doubles are 64-bit in either case. The check takes only a minute or so to run.

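    This shows that taking the floor of sqrt y gives the exact root for every perfect square y <= 3037000499^2. The test covers only perfect squares, but for n around 10^6 the conversion to Double is exact and there is plenty of precision to spare.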
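If you would rather not depend on the rounding behaviour at all, a small correction step can be added on top of the floating-point estimate. The following is only a sketch: isqrtChecked is a hypothetical name, and it assumes n >= 0 and that (r + 1) * (r + 1) does not overflow Int.

```haskell
-- Integer square root: take the floating-point estimate, then adjust it
-- until r*r <= n < (r+1)*(r+1) holds.
-- Assumes n >= 0 and no Int overflow in (r + 1) * (r + 1).
isqrtChecked :: Int -> Int
isqrtChecked n = fixUp (floor (sqrt (fromIntegral n :: Double)))
  where
    fixUp r
      | r * r > n              = fixUp (r - 1)  -- estimate was too large
      | (r + 1) * (r + 1) <= n = fixUp (r + 1)  -- estimate was too small
      | otherwise              = r
```

For the sizes discussed here the estimate should already be correct, so the guards are expected to cost only two multiplications and comparisons per call.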