Listing the representations of integer M as a sum of three squares M = x^2 + y^2 + z^2

Submitted by: @import:stackexchange-codereview

Problem

I am trying to write a number as the sum of 3 squares. For my particular project I need a complete list, since the next step is to compute nearest neighbors (here).

In any case, we'd like to find all the ways to write \$ M = x^2 + y^2 + z^2 \$. Here's what I wrote in Python:

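import numpy as np

M = 1062437
m = int(np.sqrt(M))

# k = 0 .. int(sqrt(m)) - 1: sqrt is applied twice, so k stops near M**0.25, not sqrt(M)
x = np.arange(int(np.sqrt(m)))
y = M - x**2  # remainders M - k^2

# keep (i, j) when i^2 + j^2 equals one of the remainders; the third value is then k
sq = [(i,j, int(np.sqrt(M - i**2 - j**2))) for i in range(m) for j in range(m) if i**2 + j**2 in y]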


This is easy to explain in pseudocode.

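  • build a list of remainders \$\{ M - k^2 : 0 \leq k \leq \sqrt{M} \}\$
  • loop over all possible \$0 \leq i, j < \sqrt{M}\$ and keep the pairs for which \$i^2 + j^2\$ is in that list; the third value is then \$\sqrt{M - i^2 - j^2}\$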
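
A minimal sketch of the two steps above in plain Python, using exact integer arithmetic instead of NumPy. Assumptions on my part: the hypothetical function name `three_squares_by_remainders`, `math.isqrt` (Python 3.8+), and a dict in place of the remainder array, so that each membership test is a hash lookup rather than the linear scan that `value in ndarray` performs. Here k runs all the way to ⌊√M⌋, as in the first step.

```python
import math


def three_squares_by_remainders(M):
    """Hypothetical helper: all (i, j, k) with i, j, k >= 0 and i*i + j*j + k*k == M."""
    r = math.isqrt(M)  # largest k with k*k <= M
    # Step 1: remainders M - k^2, mapped back to the k that produced them
    remainders = {M - k * k: k for k in range(r + 1)}
    out = []
    # Step 2: keep every (i, j) whose i^2 + j^2 is one of those remainders
    for i in range(r + 1):
        for j in range(r + 1):
            s = i * i + j * j
            if s in remainders:  # hash lookup, not a scan
                out.append((i, j, remainders[s]))
    return out


print(three_squares_by_remainders(9))
# [(0, 0, 3), (0, 3, 0), (1, 2, 2), (2, 1, 2), (2, 2, 1), (3, 0, 0)]
```

This still visits about M pairs (i, j), so it keeps the cost profile of the original idea; only the lookup gets cheaper.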



This is one of several possible strategies. In my case, I don't even need a fixed M; it could be a range, such as \$10^6 < M < 10^6 + 10^4\$.

I do not mind looping over values of M at this point. Can we recycle our computations somehow?

Solution

You might be interested in this question.

While looking for solutions to

$$ x^{2} + y^{2} + z^{2} = M $$

you can assume

$$ x \leq y \leq z $$

and permute the coordinates afterward if needed.
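
Recovering every ordering from the sorted triples is cheap, since each has at most 3! = 6 arrangements. A sketch using `itertools.permutations`; the function name `all_orderings` is hypothetical, and a set is used so that triples with repeated values (such as (1, 2, 2)) are not listed twice. Like the question's code, this only covers non-negative coordinates.

```python
from itertools import permutations


def all_orderings(sorted_triples):
    """Hypothetical helper: expand triples with x <= y <= z into every distinct ordering."""
    result = set()
    for t in sorted_triples:
        # permutations() yields all 6 arrangements; the set drops repeats
        result.update(permutations(t))
    return result


print(sorted(all_orderings([(0, 0, 3), (1, 2, 2)])))
# [(0, 0, 3), (0, 3, 0), (1, 2, 2), (2, 1, 2), (2, 2, 1), (3, 0, 0)]
```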

This helps because you can iterate over x in a smaller range. Since x is the smallest of the three, we know that

$$ M = x^{2} + y^{2} + z^{2} \ge x^{2} + x^{2} + x^{2} = 3x^{2} $$

therefore $$ x \le \sqrt{M / 3} $$
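
As a quick check of what this saves, take the question's M = 1062437 (the arithmetic here is mine, not from the original answer):

$$ 595^{2} = 354025 \le \frac{1062437}{3} \approx 354145.67 < 355216 = 596^{2} $$

so x only needs to run up to 595, whereas the question's loops use

$$ m = \lfloor \sqrt{1062437} \rfloor = 1030 \qquad (1030^{2} = 1060900 \le 1062437 < 1062961 = 1031^{2}) $$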

Similarly, when iterating over y, you can start at x and stop at a smaller upper bound, because:

$$ M = x^{2} + y^{2} + z^{2} \ge x^{2} + y^{2} + y^{2} = x^{2} + 2y^{2} $$

therefore $$ x \le y \le \sqrt{(M - x^{2}) / 2} $$
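
Both bounds can be computed exactly with integers. A small sketch, assuming Python 3.8+ for `math.isqrt`; the helper names `x_limit` and `y_candidates` are hypothetical. Since 3x² ≤ M holds exactly when x² ≤ ⌊M/3⌋, `math.isqrt(M // 3)` is the largest admissible x, and the same argument gives the y bound. Float-based versions such as `int(np.sqrt(M/3))` give the same values at this size; the integer form simply never depends on rounding.

```python
import math


def x_limit(M):
    """Hypothetical helper: largest x with 3*x*x <= M."""
    return math.isqrt(M // 3)


def y_candidates(M, x):
    """Hypothetical helper: y values with x <= y and x*x + 2*y*y <= M."""
    return range(x, math.isqrt((M - x * x) // 2) + 1)


M = 1062437
print(x_limit(M))              # 595
print(y_candidates(M, 0)[-1])  # 728
```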

Once x and y are fixed, you don't need to iterate over z: just compute

$$ z = \sqrt{M - x^{2} - y^{2}} $$

and check that it is indeed an integer.
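
The integer check can also be done without floating point. `math.isqrt(n)` (Python 3.8+) returns the exact floor of √n for any non-negative int, so squaring it back tells you whether n is a perfect square. `int(np.sqrt(n))` works for values of the size used here, but a float cannot represent every integer above 2**53, so the exact version is the safer habit for large M. The helper name `exact_root` is hypothetical.

```python
import math


def exact_root(n):
    """Hypothetical helper: return z if n == z*z for an integer z >= 0, else None."""
    if n < 0:
        return None  # a negative remainder has no real root
    z = math.isqrt(n)  # exact floor(sqrt(n)), no rounding
    return z if z * z == n else None


print(exact_root(1521), exact_root(1537))  # 39 None
```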

Here's a proof-of-concept piece of code. I've altered your code a bit so that the results you originally had can be compared with the ones I get:

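import numpy as np

# Setup from the question: M, m and the remainder array y
M = 1062437
m = int(np.sqrt(M))
y = M - np.arange(int(np.sqrt(m)))**2

# The question's results, each triple sorted so they can be compared with mine
sq = set(tuple(sorted([i,j, int(np.sqrt(M - i**2 - j**2))])) for i in range(m) for j in range(m) if i**2 + j**2 in y)
print(len(sq))

# Only enumerate x <= y <= z, using the bounds derived above
sol = []
lim_x = int(np.sqrt(M/3))
for x in range(1 + lim_x):
    rem_x = M - x*x
    lim_y = int(np.sqrt(rem_x / 2))
    for y in range(x, 1 + lim_y):
        rem_y = rem_x - y*y
        z = int(np.sqrt(rem_y))
        if z*z == rem_y:  # z is an integer
            assert x <= y <= z
            assert x*x + y*y + z*z == M
            sol.append((x, y, z))
print(len(sol))

# Every solution found by the original code is also found here
assert all(s in sol for s in sq)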


Your code finds 26 distinct sorted solutions, out of the 290 that mine finds (in less time).

Finally, you might be able to adapt the equations above (and the code written from it) if you want to handle ranges of M values.
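
One way to recycle work across a range of M, offered as my own sketch rather than the answer's code: run the x and y loops once using the bounds for the largest M in the range, and for each (x, y) emit every z whose total lands in [lo, hi). The function name `sorted_triples_in_range` and the half-open range convention are assumptions; `math.isqrt` needs Python 3.8+.

```python
import math
from collections import defaultdict


def sorted_triples_in_range(lo, hi):
    """Hypothetical helper: map each M in [lo, hi) to its triples x <= y <= z
    with x*x + y*y + z*z == M. Assumes 0 <= lo < hi."""
    found = defaultdict(list)
    top = hi - 1  # largest M of interest
    for x in range(math.isqrt(top // 3) + 1):  # 3x^2 <= top
        for y in range(x, math.isqrt((top - x * x) // 2) + 1):  # x^2 + 2y^2 <= top
            rest = x * x + y * y
            need = lo - rest  # z*z must be at least this
            # smallest z with z*z >= need, i.e. ceil(sqrt(need)) for need > 0
            z_min = 0 if need <= 0 else math.isqrt(need - 1) + 1
            for z in range(max(y, z_min), math.isqrt(top - rest) + 1):
                found[rest + z * z].append((x, y, z))
    return dict(found)
```

For the question's example range this would be called as `sorted_triples_in_range(10**6, 10**6 + 10**4)`. The trade-off is memory: every triple for every M in the range is held at once.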

Context

StackExchange Code Review Q#71782, answer score: 3
