矩阵连乘问题

本文最后更新于:2024年3月21日 凌晨

矩阵连乘问题

问题描述

  • 矩阵连乘积的最优计算次序问题,即对于给定的相继 n 个矩阵{A1, A2, A3,…, An}(其中,矩阵 Ai的维数为为 pi-1 x pi, i=1,2,…, n),如何确定计算矩阵连乘积 A1, A2, A3,…, An的计算次序(完全加括号方式),使得依此次序计算矩阵连乘积需要的数乘次数最少。

矩阵乘法

  • 矩阵 A 和 B 可乘的条件是矩阵 A 的列数等于矩阵 B 的行数,若 A 是一个 p x q 矩阵, B 是 q x r 矩阵,则其乘积 C=AB 是一个 p x r 矩阵,在上述计算 C 的标准算法中,主要计算量为三重循环,总共需要 pqr 次数乘。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
public class Solution {

public static void main(String[] args) {
int[][] a = {{1, 2}, {1, 2}};
int[][] b = {{1, 2}, {1, 2}};
int[][] c = new int[2][2];
new Solution().matrixMultiply(a, b, c, a.length, a[0].length, b.length, b[0].length);
for (int[] ints : c) {
System.out.println(Arrays.toString(ints));
}
}

public void matrixMultiply(int[][] a, int[][] b, int[][] c, int ra, int ca, int rb, int cb) {
//a与b矩阵不满足可乘条件。
if (ca != rb) System.exit(1);
for (int i = 0; i < ra; i++) {
for (int j = 0; j < cb; j++) {
//a矩阵的行元素与b矩阵的列元素相乘并求和。
int sum = a[i][0] * b[0][j];
for (int k = 1; k < ca; k++) {
sum += a[i][k] * b[k][j];
}
c[i][j] = sum;
}
}
}
}


  • matrixMultiply(int[][] a, int[][] b, int[][] c, int ra, int ca, int rb, int cb):计算矩阵的乘积。
    • int[][] a/int[][] b:待相乘的矩阵。
    • int[][] c:存放结果的矩阵。
    • int ra/int ca:矩阵 a 的行和列。
    • int rb/int cb:矩阵 b 的行和列。

算法分析

  • 对于矩阵连乘积的最优计算次序问题,设计算 A[i:j], i < = k < j,所需要的最少数乘次数为 m[i][j],则原问题的最优值为 m[1][n]
  • i = j 时, A[i:j] = A,为单一矩阵,无需计算,因此 m[i][i] = 0, `i = 1,2,3,…,n
  • i < j 时,可利用最优子结构性质来计算 m[i][j],事实上,若计算 A[i:j] 的最优次序在 Ak和 Ak+1之间断开, i <= k < j,则m[i][j] = min{m[i][k] + m[k+1][j] + pi-1pkpj}, i < j
  • 若将对应于 m[i][j] 的断开位置 k 记为 s[i][j],在计算出最优值 m[i][j] 后,可递归地由 s[i][j] 构造出相应的最优解。

代码实现

流程图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
public class Solution {

public static void main(String[] args) {
int[] p = {30, 35, 15, 5, 10, 20, 25};
int n = p.length - 1;
int[][] m = new int[n + 1][n + 1];
int[][] s = new int[n + 1][n + 1];
Solution solution = new Solution();
solution.matrixChain(p, n, m, s);

System.out.println("---------------m[i][j]-----------------");
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
System.out.printf("%6d", m[i][j]);
}
System.out.println();
}
System.out.println("---------------s[i][j]-----------------");
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
System.out.printf("%6d", s[i][j]);
}
System.out.println();
}
System.out.println("最优值:" + m[1][n]);
System.out.println("---------------最优解-----------------");
traceback(1, n, s);
System.out.println("计算次数:" + solution.count);
}

public int count = 0;

public static void traceback(int i, int j, int[][] s) {
if (i == j) return;
traceback(i, s[i][j], s);
traceback(s[i][j] + 1, j, s);
System.out.print("Multiply A[" + i + "][" + s[i][j] + "]");
System.out.println(" and A[" + (s[i][j] + 1) + "][" + j + "]");
}

public void matrixChain(int[] p, int n, int[][] m, int[][] s) {
//m[i][i]为单一矩阵,计算量为0
for (int i = 1; i <= n; i++) {
m[i][i] = 0;
}
//r为矩阵链的长度。
for (int r = 2; r <= n; r++) {
// 以r为长度计算矩阵连乘积。
for (int i = 1; i <= n - r + 1; i++) {
int j = i + r - 1;
// 开始划分矩阵链。
m[i][j] = m[i][i] + m[i + 1][j] + p[i - 1] * p[i] * p[j];
count++;
s[i][j] = i;
for (int k = i + 1; k < j; k++) {
int t = m[i][k] + m[k + 1][j] + p[i - 1] * p[k] * p[j];
count++;
// 将最小数乘次数的信息保存。
if (t < m[i][j]) {
m[i][j] = t;
s[i][j] = k;
}
}
}
}
}
}
  • matrixChain(int[] p, int n, int[][] m, int[][] s):计算最优矩阵连乘积。
    • int[] p:矩阵链中矩阵的维数,相继 n 个矩阵{A1, A2, A3,…, An}(其中,矩阵 Ai的维数为为 pi-1 x pi, i=1,2,…, n)
    • int n:矩阵链的长度。
    • int[][] m:存放矩阵连乘积的大小, m[i][j] 表示 A[i: j]的最小矩阵连乘积。
    • int[][] s:存放达成矩阵连乘积最小时划分点的位置,s[i][j] 表示 m[i][j] 中 k 的位置。
  • traceback(int i, int j, int[][] s):计算最优计算次序。
    • int i:矩阵链的起始位置。
    • int j:矩阵链的结束位置。
    • int[][] s:存放达成矩阵连乘积最小时划分点的位置,s[i][j] 表示 m[i][j] 中 k 的位置。

递归实现矩阵连乘算法

代码实现

流程图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
public class Solution {

public static void main(String[] args) {
int[] p = {30, 35, 15, 5, 10, 20, 25};
int n = p.length - 1;
int[][] s = new int[n + 1][n + 1];
Solution solution = new Solution();
int x = solution.matrixChain(p, s, 0, n);
System.out.println("---------------s[i][j]-----------------");
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
System.out.printf("%6d", s[i][j]);
}
System.out.println();
}
System.out.println("最优值:" + x);
System.out.println("---------------最优解-----------------");
solution.traceback(1, n, s);
System.out.println("计算次数:" + solution.count);

}

public int count = 0;

public void traceback(int i, int j, int[][] s) {
if (i == j) return;
traceback (i, s[i][j], s);
traceback (s[i][j] + 1, j, s);
System.out.print ("Multiply A[" + i + "][" + s[i][j] + "]");
System.out.println (" and A[" + (s[i][j] + 1) + "][" + j + "]");
}

public int matrixChain (int[] p, int[][] s, int left, int right) {
if (left + 1 >= right) return 0;
int i = left + 1;
int min = matrixChain (p, s, left, i) + matrixChain (p, s, i, right) + p[left] * p[i] * p[right];
count++;
int minIndex = i;
for (i = left + 2; i < right; i++) {
int t = matrixChain (p, s, left, i) + matrixChain (p, s, i, right) + p[left] * p[i] * p[right];
count++;
// 将最小数乘次数的信息保存。
if (t < min) {
min = t;
minIndex = i;
}
}
s[left + 1][right] = minIndex;
return min;
}

}
  • matrixChain (int[] p, int[][] s, int left, int right):计算最优矩阵连乘积。
    • int[] p:矩阵链中矩阵的维数,相继 n 个矩阵{A1, A2, A3,…, An}(其中,矩阵 Ai的维数为为 pi-1 x pi, i=1,2,…, n)
    • int[][] s:存放达成矩阵连乘积最小时划分点的位置,s[i][j]表示m[i][j]中 k 的位置。
    • int left:表示求解的上限。
    • int left:表示求解的下限。
  • traceback (int i, int j, int[][] s):计算最优计算次序。
    • int i:矩阵链的起始位置。
    • int j:矩阵链的结束位置。
    • int[][] s:存放达成矩阵连乘积最小时划分点的位置,s[i][j]表示m[i][j]中 k 的位置。

本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!