Java Matrix multiplication using Multithreading
Problem
I am trying to implement matrix multiplication using concurrency and a `CyclicBarrier`.
Here is the code snippet:
```
import Jama.Matrix;
public class Application {
    public final CyclicBarrier cyclicBarrier;
    public Matrix A = new Matrix(200,120);
    public Matrix B = new Matrix(120,100);
    public Matrix C = new Matrix(200,100);
    public Application() {
        initializeMatrix(A,A.getRowDimension(),A.getColumnDimension());
        initializeMatrix(B,B.getRowDimension(),B.getColumnDimension());
        cyclicBarrier = new CyclicBarrier(4, new Runnable() {
            @Override
            public void run() {
                display(C);
            }
        });
        new Thread(new Task(50,25)).start();
        new Thread(new Task(100,50)).start();
        new Thread(new Task(150,75)).start();
        new Thread(new Task(200,100)).start();
    }
    public void initializeMatrix(Matrix X,int row,int col){
        for (int i=0;i<row;i++)
            for (int j=0;j<col;j++)
                X.set(i,j,Math.random());
    }
    public void Multiply(int rowend, int col){
        for (int i=0;i<rowend;i++){
            for (int j=0;j<col;j++){
                double sum = 0;
                for (int k=0;k<B.getRowDimension();k++){
                    sum += A.get(i,k)*B.get(k,j);
                }
                C.set(i,j,sum);
            }
        }
        display(C);
    }
    public void display(Matrix X){
        System.out.print("C = [");
        for (int i=0;i<X.getRowDimension();i++){
            for (int j=0;j<X.getColumnDimension();j++){
                System.out.print(X.get(i,j) + ",");
            }
            System.out.println();
        }
        System.out.print(" ]\n\n");
    }
    private class Task implements Runnable{
        int row = 0;
        int col = 0;
        public Task(int row, int col){
            this.row = row;
            this.col = col;
        }
        @Override
        public void run() {
            Multiply(row, col);
            try {
                cyclicBarrier.await();
            } catch (InterruptedException iex) {
                System.out.println(iex.getMessage());
            } catch (BrokenBarrierException bbex) {
```
Solution
Your basic idea is right: matrix multiplication lends itself to heavy multithreading, because each entry of the result is the scalar product of one row of A and one column of B and can be computed independently of the others. But I heavily doubt you need three nested loops for a scalar product. Here's my approach (it is not tested, and you should check beforehand that `B.getRowDimension() == A.getColumnDimension()` and that C has the matching dimensions):

```
public void multiply(int row, int col) {
    // Accumulate in a double: an int accumulator would truncate every partial sum.
    double sum = 0;
    for (int r = 0; r < B.getRowDimension(); r++)
        sum += A.get(row, r) * B.get(r, col);
    C.set(row, col, sum);
}
```

This should speed things up a lot already. This method is the innermost loop of the multiplication. The two outer loops can run concurrently, and normally you use an `ExecutorService` and `Runnable` to do this.

```
// Needs java.util.concurrent.ExecutorService, Executors and TimeUnit.
ExecutorService executor = Executors.newCachedThreadPool();
for (int x = 0; x < C.getRowDimension(); x++)
    for (int y = 0; y < C.getColumnDimension(); y++) {
        final int row = x, col = y;   // lambdas may only capture effectively final locals
        executor.submit(() -> multiply(row, col));
    }
executor.shutdown();
// Blocks until every submitted task has finished; throws InterruptedException.
executor.awaitTermination(Long.MAX_VALUE, TimeUnit.DAYS);
display(C);
```

You already did the right thing by using a Runnable and passing it the position of the cell you want to calculate, but your Runnable shouldn't wait at the end. In my case, the Runnable is created implicitly through lambdas. After all tasks have been submitted, the program waits for them to terminate and then prints the result.

I don't know your implementation of the Matrix, but please make sure it can be written to from multiple threads concurrently without corrupting its state. Modification does not have to be completely thread safe, because each cell is only ever changed by one thread. If you are using an array internally to store the data and the `set(r, c, v)` method does not modify any state apart from values in that array, it should be fine.

Code Snippets
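For reference, the answer's two fragments can be combined into one self-contained program. This is only a minimal sketch under stated assumptions, not the answerer's code: it swaps Jama's `Matrix` for plain `double[][]` arrays so that it runs without the library, it uses a fixed-size pool with one task per row instead of a cached pool with one task per cell, and the names `ParallelMultiply`, `multiplyRow` and the `threads` parameter are made up for illustration.

```java
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Hypothetical class name; plain arrays stand in for Jama's Matrix (assumption).
class ParallelMultiply {

    // Computes one row of c = a * b. Each task writes only to its own row of c,
    // so no two tasks ever write to the same cell.
    static void multiplyRow(double[][] a, double[][] b, double[][] c, int row) {
        for (int col = 0; col < b[0].length; col++) {
            double sum = 0;
            for (int k = 0; k < b.length; k++) {
                sum += a[row][k] * b[k][col];
            }
            c[row][col] = sum;
        }
    }

    // Multiplies a by b using a fixed pool of the given size (one task per row).
    static double[][] multiply(double[][] a, double[][] b, int threads) {
        if (a.length == 0 || b.length == 0 || a[0].length != b.length) {
            throw new IllegalArgumentException("inner dimensions must match");
        }
        double[][] c = new double[a.length][b[0].length];
        ExecutorService executor = Executors.newFixedThreadPool(threads);
        try {
            for (int x = 0; x < a.length; x++) {
                final int row = x; // lambdas may only capture effectively final locals
                executor.submit(() -> multiplyRow(a, b, c, row));
            }
        } finally {
            executor.shutdown(); // accept no new tasks; queued ones still run
        }
        try {
            // Wait for all row tasks to finish before handing back the result.
            if (!executor.awaitTermination(1, TimeUnit.HOURS)) {
                throw new IllegalStateException("multiplication timed out");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting", e);
        }
        return c;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2}, {3, 4}};
        double[][] b = {{5, 6}, {7, 8}};
        System.out.println(Arrays.deepToString(multiply(a, b, 2)));
        // prints [[19.0, 22.0], [43.0, 50.0]]
    }
}
```

One task per row keeps the number of submitted tasks small (200 instead of 20,000 for the matrices in the question) while still leaving every cell with exactly one writer. Note that `submit` stores any exception thrown by a task in the returned `Future`, so a more careful version would keep those futures and call `get()` on each.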
```
public void multiply(int row, int col) {
    // Accumulate in a double: an int accumulator would truncate every partial sum.
    double sum = 0;
    for (int r = 0; r < B.getRowDimension(); r++)
        sum += A.get(row, r) * B.get(r, col);
    C.set(row, col, sum);
}
```
```
ExecutorService executor = Executors.newCachedThreadPool();
for (int x = 0; x < C.getRowDimension(); x++)
    for (int y = 0; y < C.getColumnDimension(); y++) {
        final int row = x, col = y;
        executor.submit(() -> multiply(row, col));
    }
executor.shutdown();
executor.awaitTermination(Long.MAX_VALUE, TimeUnit.DAYS);
display(C);
```
Context
StackExchange Code Review Q#159724, answer score: 2