Optimizing a sum of matrix chains

Submitted by: @import:stackexchange-cs

Problem

Edit (Jan 31): an important special case is when the sums form a nested structure; search for "Hasse diagram is a tree" below.

Here's a practically relevant variation on the matrix chain problem:

Find the optimal way to compute a sum over all weighted paths in a graph, where the weight of each path is the matrix product of its edge labels (i.e., a matrix chain).

For instance, take $Q$ to be the sum
$$Q=A_0 A_1 A_2 A_3 A_4 +A_0 A_1 A_3 A_4 +A_0 A_1 A_4$$

We can represent it as the following sum over weighted paths:

Now, we can count the number of matrix multiplications (in red) involved in computing this sum:

Here's a more efficient way to compute $Q$:

$$Q=A_0A_1(A_2+I)A_3A_4 + A_0A_1A_4$$

We can view it as the following sum over paths:

Some edges are labeled with $I$, corresponding to multiplication by the identity matrix.

Given a list of matrix dimensions $d_0,\ldots,d_n$ corresponding to matrices $A_0,\ldots,A_{n-1}$ and a list of paths, the task is to figure out a sequence of matrix multiplications and additions which produces $Q$ using the smallest total number of scalar multiplications.
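To make the objective concrete, here is a minimal sketch of the cost model. It assumes the usual convention that multiplying a $p \times q$ matrix by a $q \times r$ matrix costs $pqr$ scalar multiplications, and that $A_k$ has shape $d_k \times d_{k+1}$. The function name `left_to_right_cost` and the all-$2 \times 2$ dimensions are illustrative assumptions, not part of the original post. It computes the baseline cost of evaluating each term of $Q$ independently, left to right, with no sharing.

```python
def left_to_right_cost(indices, d):
    """Scalar multiplications needed to multiply A_k for k in `indices`,
    strictly left to right, where A_k has shape d[k] x d[k + 1]."""
    rows = d[indices[0]]
    inner = d[indices[0] + 1]
    total = 0
    for k in indices[1:]:
        if d[k] != inner:
            raise ValueError("shape mismatch before A%d" % k)
        total += rows * inner * d[k + 1]
        inner = d[k + 1]
    return total


# Assumed dimensions: five 2x2 matrices A_0..A_4.
d = [2, 2, 2, 2, 2, 2]

# The three terms of Q, written as lists of matrix indices.
terms = [[0, 1, 2, 3, 4], [0, 1, 3, 4], [0, 1, 4]]

naive_cost = sum(left_to_right_cost(t, d) for t in terms)
print(naive_cost)  # 72, i.e. 9 matrix products at 8 scalar multiplications each
```

Any schedule that shares subproducts or factors out sums can be compared against this baseline.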

Specifying Paths

Each term in the sum is specified as a pair of numbers $(i,j)$, indicating that the matrices $A_{i+1},A_{i+2},\ldots,A_j$ are not present in the term. The empty pair $()$ denotes the full chain with nothing skipped.

For instance, for the problem above, the paths are $[(), (1,2), (1,3)]$. The second term is missing the matrices $\{A_2\}$ and the third term is missing $\{A_2,A_3\}$.
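This encoding can be sketched as a small helper that expands a path specification into the indices of the matrices that remain in the term. The helper name `term_indices` is hypothetical; the use of an empty tuple for the full chain follows the example above.

```python
def term_indices(n, path):
    """Indices of the matrices present in one term of the sum.

    `path` is () for the full chain, or (i, j) meaning that
    A_{i+1}, ..., A_j are skipped.
    """
    if not path:
        return list(range(n))
    i, j = path
    return [k for k in range(n) if not (i < k <= j)]


paths = [(), (1, 2), (1, 3)]
terms = [term_indices(5, p) for p in paths]
print(terms)  # [[0, 1, 2, 3, 4], [0, 1, 3, 4], [0, 1, 4]]
```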

Viewing paths as connected subgraphs of the chain graph, paths are partially ordered by the subgraph relation, so the set of paths can be drawn as a Hasse diagram.

An important special case is when the Hasse diagram is a tree.

For example, consider this sum.

$$W=A_0 A_1 A_3 A_4 + A_0 A_1 A_2 A_4 + A_0 A_1 A_4 + A_0 A_1 A_2 A_3$$

And the corresponding Hasse diagram:
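One way to read the tree condition, assumed here rather than stated in the original, is that the skipped ranges are pairwise either nested or disjoint, so every range has at most one smallest enclosing range. A minimal sketch of that check follows, using the $(i,j)$ encoding from the section on specifying paths. The helper names `skipped` and `is_nested_family` are hypothetical.

```python
def skipped(path):
    """Set of matrix indices skipped by the pair (i, j): A_{i+1}..A_j."""
    i, j = path
    return set(range(i + 1, j + 1))


def is_nested_family(paths):
    """True if every two skipped ranges are either nested or disjoint."""
    sets = [skipped(p) for p in paths if p]
    for a in range(len(sets)):
        for b in range(a + 1, len(sets)):
            s, t = sets[a], sets[b]
            if s & t and not (s <= t or t <= s):
                return False
    return True


# Skipped ranges for the four terms of W: {A2}, {A3}, {A2, A3}, {A4}.
w_paths = [(1, 2), (2, 3), (1, 3), (3, 4)]
print(is_nested_family(w_paths))           # True
print(is_nested_family([(1, 3), (2, 4)]))  # False: {A2, A3} and {A3, A4} overlap
```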

Coming back to the original example $Q$, we can order the matrix products in the following way:
$$Q=(A_0 A_1) A_4 + (A_0 A_1) (A_3 A_4) + ((A_0 A_1) A_2) (A_3 A_4)$$

Notice that some terms like $A_0 A_1$ are repeated.
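The effect of reuse can be counted directly. The sketch below represents each parenthesized product as a nested tuple of matrix indices (an illustrative encoding, not from the original post) and compares the number of matrix products with and without sharing of repeated subproducts.

```python
def count_products(expr):
    """Number of matrix products needed to evaluate `expr` with no reuse."""
    if isinstance(expr, int):
        return 0
    left, right = expr
    return 1 + count_products(left) + count_products(right)


def distinct_products(expr, seen):
    """Add every product node of `expr` to `seen`, so repeats count once."""
    if not isinstance(expr, int):
        left, right = expr
        distinct_products(left, seen)
        distinct_products(right, seen)
        seen.add(expr)
    return seen


a01 = (0, 1)  # A0 A1
a34 = (3, 4)  # A3 A4
# (A0 A1) A4 + (A0 A1)(A3 A4) + ((A0 A1) A2)(A3 A4)
terms = [(a01, 4), (a01, a34), ((a01, 2), a34)]

without_reuse = sum(count_products(t) for t in terms)

shared = set()
for t in terms:
    distinct_products(t, shared)
with_reuse = len(shared)

print(without_reuse, with_reuse)  # 9 6
```

This only counts matrix products; weighting each product by its dimensions gives the scalar multiplication count.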

Solution

One idea would be to generalize the O(N^3) DP for the case of a single path without skips to your case:

d[i][j] would be the cost of computing all the products of the matrices in the half-open range [A_i ... A_j), for all the paths whose missing matrices lie fully within that range. It can be computed by iterating over a split point inside [i, j) and recursively computing the best cost for the left and the right part. If a subtree lies entirely on one side of the split point, then I believe it just works. If the split point falls in the middle of a subtree (there will be at most one such subtree), the DP needs to be called recursively for that subtree, and the result needs to be added to the cost.
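For reference, the base DP being generalized is the textbook matrix chain recurrence over half-open ranges $[i, j)$. The sketch below covers only the single-chain case with no skipped matrices. The function name `chain_order_cost` is an assumption, and the code does not implement the subtree handling described above.

```python
from functools import lru_cache


def chain_order_cost(d):
    """Minimum scalar multiplications to compute A_0 ... A_{n-1},
    where A_k has shape d[k] x d[k + 1]."""
    n = len(d) - 1

    @lru_cache(maxsize=None)
    def best(i, j):
        # Cost of the product of the matrices in the half-open range [i, j).
        if j - i == 1:
            return 0
        return min(
            best(i, mid) + best(mid, j) + d[i] * d[mid] * d[j]
            for mid in range(i + 1, j)
        )

    return best(0, n)


print(chain_order_cost([10, 30, 5, 60]))      # 4500
print(chain_order_cost([40, 20, 30, 10, 30]))  # 26000
```

There are $O(N^2)$ ranges and $O(N)$ split points per range, which gives the $O(N^3)$ bound mentioned above.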

Below I have the code, without the proper handling of subtrees. For your particular example it already improves from 7 to 6 multiplications.
...

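# `cost` and the tensor type (which supports `@` and `+`) come from the part of
# the listing elided above.
def compute_schedule(tensors, paths, cache, l=0, r=-1):
    # Returns (tensor, schedule string) for the half-open range [l, r).
    if r == -1:
        r = len(tensors)

    if r == l + 1:
        return tensors[l], "A%s" % l

    if (l, r) in cache:
        return cache[(l, r)]

    ret = None
    for mid in range(l + 1, r):
        lm, ls = compute_schedule(tensors, paths, cache, l, mid)
        rm, rs = compute_schedule(tensors, paths, cache, mid, r)
        cur = [lm @ rm, "%s @ %s" % (ls, rs)]

        # If the Hasse diagram is a tree, `mid` will lie within at most one
        # child. The correct thing to do here would be to call the DP for such
        # a subtree.
        # I don't do it here, and instead just compute the naive product from
        # left to right for all the runs in such a subtree. This is not optimal
        # in the general case.
        need_parens = False
        for path in paths:
            # NOTE: the comparison operators in the two conditions below were
            # lost in extraction; the lines are kept verbatim and are not valid
            # Python as shown.
            if l = path[0] and mid = path[1]:
                t = tensors[l]
                ts = "A%s" % l
                for k in range(l + 1, r):
                    if k = path[1]:
                        # Don't actually need to multiply tensors here; it is
                        # enough to compute the shape.
                        t @= tensors[k]
                        ts += " @ A%s" % k
                cur[0] += t
                cur[1] += " + %s" % ts
                need_parens = True

        if need_parens:
            cur[1] = "(" + cur[1] + ")"

        cur = tuple(cur)
        # The source is truncated after `cost(cur[0])`; the rest of this
        # function is reconstructed from context (keep the cheapest schedule,
        # memoize it, return it).
        if ret is None or cost(cur[0]) < cost(ret[0]):
            ret = cur

    cache[(l, r)] = ret
    return ret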

Yields

A0 @ ((A1 @ A2 + A1) @ A3 + A1) @ A4
Schedule2 requires 4 scalar multiplications
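The printed schedule can be checked against the definition of $Q$ by expanding both numerically. The sketch below uses plain Python lists and arbitrary small integer matrices, so the comparison is exact. The helper names and the test values are assumptions for illustration.

```python
def matmul(x, y):
    """Product of two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)]
            for row in x]


def matadd(x, y):
    """Elementwise sum of two matrices of the same shape."""
    return [[a + b for a, b in zip(r, s)] for r, s in zip(x, y)]


def chain(*ms):
    """Left-to-right product of one or more matrices."""
    out = ms[0]
    for m in ms[1:]:
        out = matmul(out, m)
    return out


# Arbitrary 2x2 integer test matrices A0..A4.
A0, A1, A2, A3, A4 = [[[k + 1, k], [1, k + 2]] for k in range(5)]

# Q = A0 A1 A2 A3 A4 + A0 A1 A3 A4 + A0 A1 A4, term by term.
q_direct = matadd(
    matadd(chain(A0, A1, A2, A3, A4), chain(A0, A1, A3, A4)),
    chain(A0, A1, A4),
)

# A0 @ ((A1 @ A2 + A1) @ A3 + A1) @ A4, which uses four matrix products.
inner = matadd(matmul(matadd(matmul(A1, A2), A1), A3), A1)
q_schedule = chain(A0, inner, A4)

print(q_schedule == q_direct)  # True
```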

Context

StackExchange Computer Science Q#147773, answer score: 3
