
Optimizing a function in Eigen

Submitted by: @import:stackexchange-codereview

Problem

I'm a beginner in C++ and would appreciate advice on optimizing the following function, which I wrote with Eigen (in fact, to be used with RcppEigen).

So far, I observe a 3.5x speed-up compared to the corresponding function written in R, and I was wondering whether I could gain more.

Note that I'm working with very large matrices, so I rely on maps to avoid copying the corresponding R objects.

Thanks in advance for your help!

#include <RcppEigen.h>

using namespace Rcpp; 
using namespace Eigen;

typedef Map<ArrayXd> MapArr1D;
typedef Map<ArrayXXd> MapArr2D;
typedef Map<MatrixXd> MapMat;
typedef Map<VectorXd> MapVec;

// [[Rcpp::depends(RcppEigen)]]

// [[Rcpp::export]]
void myFct(const MapMat M1, const MapMat M2, MapMat M3, MapMat M4, MapArr2D A1, 
           MapArr2D A2, const MapArr1D a1, const MapArr1D a2, const MapArr1D a3, 
           const MapArr1D a4, const MapArr1D a5, const double d1) {

  for (int j = 0; j < M1.cols(); ++j) {

    // Remove the current contribution of column j from M4.
    M4.noalias() -= M1.col(j) * M3.row(j);

    A1.row(j) = a1 * a2 * ((M2 - M4).transpose() * M1.col(j)).array();

    A2.row(j) = exp(-Fct(a3(j) - a4(j) - a5 / 2 - d1 / 2 - 
      pow(A1.row(j).transpose(), 2) / (2 * a1) - log(a1) / 2));

    M3.row(j) = A1.row(j) * A2.row(j);

    // Add the updated contribution of column j back to M4.
    M4.noalias() += M1.col(j) * M3.row(j);

  }
}


where Fct is some other function.

Solution

It's hard to help without knowing the sizes, but:

- The first thing to do would be to measure the relative cost of each of the five statements to see where the bottleneck is.
- The term -(a5 + d1 + log(a1)) / 2 does not depend on j, so it could be precomputed into a temporary outside the loop.
- Replace pow(A1.row(j).transpose(), 2) with A1.row(j).transpose().square().
- Compilation flags might also help: make sure AVX and, if supported, FMA are enabled, e.g., with -march=native.
- Depending on the sizes, you might want to rewrite the expressions to benefit from faster matrix-matrix products:

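T = (M2 - M4).transpose() * M1;
for (j ...)
  ...
  A1.row(j) = a1 * a2 * (T.col(j) + M1.col(j).squaredNorm() * M3.row(j).transpose()).array();
  ...

Here T.col(j) stands for (M2 - M4).transpose() * M1.col(j) with M4 as it is at the start of iteration j, i.e., before M4 -= M1.col(j) * M3.row(j); the effect of that statement therefore has to be added back, hence the plus sign. Because M4 changes at every iteration, T must in turn be updated after each column, by subtracting (new M3.row(j) - old M3.row(j)).transpose() * (M1.col(j).transpose() * M1), unless the columns of M1 are mutually orthogonal.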


Context

StackExchange Code Review Q#157696, answer score: 3
