# [Square Digit Chains](https://projecteuler.net/problem=92)

It's easy enough to write a function that calculates the sum of the squares of the digits of a number.

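```python
def sum_of_digit_squares(n):
    """Return the sum of the squares of the decimal digits of n."""
    total = 0
    while n != 0:
        total += (n % 10) ** 2  # square the last digit (** is exponentiation; ^ is XOR in Python)
        n //= 10                # drop the last digit
    return total
```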
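To see what a chain looks like, here is a small sketch. It uses a hypothetical helper `chain`, which is not part of the notebook, that keeps applying `sum_of_digit_squares` until a value repeats. It traces two small examples, 44 and 85.

```python
def sum_of_digit_squares(n):
    """Return the sum of the squares of the decimal digits of n."""
    total = 0
    while n != 0:
        total += (n % 10) ** 2
        n //= 10
    return total

def chain(n):
    """Hypothetical helper: list the chain from n, stopping at the first repeated value."""
    seen = []
    while n not in seen:
        seen.append(n)
        n = sum_of_digit_squares(n)
    seen.append(n)  # include the repeated value to show where the loop closes
    return seen

print(chain(44))  # [44, 32, 13, 10, 1, 1]
print(chain(85))  # [85, 89, 145, 42, 20, 4, 16, 37, 58, 89]
```

The first chain gets stuck at 1. The second enters the loop through 89 and comes back around to it.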

We can also write a memoized recursive function for determining whether a number arrives at 1 or 89 eventually.

```python
from functools import cache

@cache
def arrives_at(n):
    """Follow the square digit chain from n; return where it ends (0, 1 or 89)."""
    if n in (0, 1, 89):
        return n
    return arrives_at(sum_of_digit_squares(n))
```
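The recursion only terminates if every chain really does reach 1 or 89, and this claim is cheap to check. A number below $10^7$ has at most seven digits, so its first step lands at most at $7 \cdot 9^2 = 567$. Checking the starting values 1 through 567 therefore covers every chain we will follow. The sketch below performs that check with an iterative, cycle-detecting helper. `reaches_1_or_89` and `BOUND` are hypothetical names for illustration and are not part of the notebook.

```python
def sum_of_digit_squares(n):
    """Return the sum of the squares of the decimal digits of n."""
    total = 0
    while n != 0:
        total += (n % 10) ** 2
        n //= 10
    return total

# Largest digit-square sum of any number below 10**7 (seven 9s): 7 * 81.
BOUND = 7 * 9 ** 2

def reaches_1_or_89(n):
    """Follow the chain from n; True if it hits 1 or 89 before any value repeats."""
    seen = set()
    while n not in seen:
        if n in (1, 89):
            return True
        seen.add(n)
        n = sum_of_digit_squares(n)
    return False  # stuck in some other cycle

# After one step every starting number below 10**7 lands in 0..BOUND,
# so checking 1..BOUND covers every chain we will follow (0 maps to itself).
assert all(reaches_1_or_89(n) for n in range(1, BOUND + 1))
print("every n in 1..%d reaches 1 or 89" % BOUND)
```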

At this point, we could just evaluate `arrives_at` at every starting number below ten million and count how many arrive at 89, but following ten million chains is slow. We can speed this up by observing that any two numbers whose digits are permutations of each other have the same sum of digit squares. Their chains therefore arrive at the same number. For example, the digit squares of 112345 and 523141 both sum to 56, so both chains arrive at the same number (89), and we only need to check one of them.
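A quick sketch of this observation follows. `digit_multiset` is a hypothetical helper that uses the sorted digit string as a canonical key, so two numbers share a key exactly when their digits are permutations of each other.

```python
from itertools import permutations

def sum_of_digit_squares(n):
    """Return the sum of the squares of the decimal digits of n."""
    total = 0
    while n != 0:
        total += (n % 10) ** 2
        n //= 10
    return total

def digit_multiset(n):
    """Hypothetical helper: a canonical key for n's digits (the sorted digit string)."""
    return "".join(sorted(str(n)))

# The two numbers from the example share a digit multiset ...
assert digit_multiset(112345) == digit_multiset(523141) == "112345"
# ... and therefore a digit-square sum.
assert sum_of_digit_squares(112345) == sum_of_digit_squares(523141) == 56

# Every rearrangement of 112345's digits gives the same sum.
sums = {sum_of_digit_squares(int("".join(p))) for p in permutations("112345")}
print(sums)  # {56}
```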

Suppose we check each multiset of digits only once. How many candidates does that leave? Pad every starting number below $10^7$ with leading zeros to exactly seven digits. The answer is then the number of [multisets](https://en.wikipedia.org/wiki/Multiset) of cardinality 7 drawn from $\{0,1,2,\ldots,9\}$, which is given by
$${10+7-1 \choose 7} = 11440$$
Much better than $10^7$.
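This count can be checked against a direct enumeration. The sketch below uses `math.comb` (Python 3.8+) and the same `combinations_with_replacement` call that the main loop uses. Each sorted 7-tuple that this call yields corresponds to one multiset.

```python
from itertools import combinations_with_replacement
from math import comb

# Stars and bars: multisets of size 7 drawn from the 10 digits.
formula = comb(10 + 7 - 1, 7)
# Direct enumeration: each sorted 7-tuple of digits is one multiset.
enumerated = sum(1 for _ in combinations_with_replacement(range(10), 7))
print(formula, enumerated)  # 11440 11440
```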

However, there is a catch. Continuing the example above, we only check whether 112345 arrives at 89, but all of its distinct digit permutations also arrive at 89, so every one of them has to be counted in the final total. Because digits can repeat, that count isn't simply 7!. Fortunately, the number of distinct permutations is given by the [multinomial coefficient](https://en.wikipedia.org/wiki/Multinomial_theorem). Write each number as a seven-digit string padded with leading zeros. If the digit 0 appears $k_0$ times, the digit 1 appears $k_1$ times, and so on, then the number of distinct permutations of the digits is
$${7 \choose {k_0,k_1,k_2,\ldots,k_9}} = \frac{7!}{k_0! k_1! k_2!\cdots k_9!}$$
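As a concrete example, 112345 padded to seven digits is 0112345. The digit 1 appears twice and 0, 2, 3, 4, 5 appear once each, so there are $7!/2! = 2520$ distinct arrangements. The sketch below computes this with a plain-Python `multinomial` helper, which is an assumption standing in for whatever the notebook's environment provides. It also checks that the arrangement counts of all 11440 multisets add up to $10^7$, so every padded starting number is counted exactly once.

```python
from collections import Counter
from itertools import combinations_with_replacement
from math import factorial

def multinomial(ks):
    """(sum ks)! / prod(k!): the number of distinct arrangements."""
    result = factorial(sum(ks))
    for k in ks:
        result //= factorial(k)  # exact at every step
    return result

# 112345 padded to seven digits is 0112345: digit 1 twice; 0, 2, 3, 4, 5 once each.
counts = Counter("0112345")
print(multinomial(counts.values()))  # 2520

# Sanity check: the arrangements of all 11440 multisets cover every 7-digit string once.
total = sum(multinomial(Counter(d).values())
            for d in combinations_with_replacement(range(10), 7))
print(total == 10 ** 7)  # True
```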

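```python
from collections import Counter
from itertools import combinations_with_replacement
from math import factorial

def multinomial(ks):
    """(sum ks)! / (k_0! k_1! ...); SageMath has a builtin, this covers plain Python."""
    result = factorial(sum(ks))
    for k in ks:
        result //= factorial(k)
    return result

total = 0
# Each sorted 7-tuple of digits is one multiset; leading zeros pad shorter numbers.
for digits in combinations_with_replacement(range(10), 7):
    n = sum(10 ** k * d for (k, d) in enumerate(reversed(digits)))
    if arrives_at(n) == 89:
        # Count every distinct arrangement of these seven digits.
        total += multinomial(Counter(digits).values())

total
```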

8581146
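One cheap way to gain confidence in the multiset shortcut is to compare it against brute force for shorter numbers. The sketch below does this for starting numbers below $10^1$ through $10^4$, parameterized by a hypothetical `num_digits`. It is a small variation on the code above, not the notebook's own code. The multiset version passes the digit-square sum straight to `arrives_at` instead of rebuilding the number, since only that sum matters.

```python
from collections import Counter
from functools import cache
from itertools import combinations_with_replacement
from math import factorial

def sum_of_digit_squares(n):
    total = 0
    while n != 0:
        total += (n % 10) ** 2
        n //= 10
    return total

@cache
def arrives_at(n):
    if n in (0, 1, 89):
        return n
    return arrives_at(sum_of_digit_squares(n))

def multinomial(ks):
    result = factorial(sum(ks))
    for k in ks:
        result //= factorial(k)
    return result

def count_89_brute(num_digits):
    """Check every starting number below 10**num_digits directly."""
    return sum(1 for n in range(1, 10 ** num_digits) if arrives_at(n) == 89)

def count_89_multisets(num_digits):
    """Check one representative per digit multiset, weighted by its arrangements."""
    total = 0
    for digits in combinations_with_replacement(range(10), num_digits):
        s = sum(d * d for d in digits)  # only the digit-square sum matters
        if arrives_at(s) == 89:
            total += multinomial(Counter(digits).values())
    return total

for num_digits in range(1, 5):
    print(num_digits, count_89_brute(num_digits), count_89_multisets(num_digits))
```

Calling `count_89_multisets(7)` applies the same method to the full problem.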

## Related sequences
* Numbers we iterate over: [A009994](https://oeis.org/A009994)