Why is matrix-matrix multiplication abbreviated GEMM? Because the name comes from BLAS: GEneral Matrix-Matrix multiplication (the double-precision routine is DGEMM).
The three basic optimization techniques (a toy sketch of all three follows the list):
- Using pointers
- Loop unrolling
- Register variables

These sound so familiar, as if I were back in my CSAPP days.
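A toy sketch of the three techniques applied to a dot product (function and variable names are mine, purely for illustration; "register" is only a hint to a modern compiler, but it is the CSAPP-style way to state the intent):

#include <stddef.h>

double ddot_opt(const double *x, const double *y, size_t n)
{
    register double acc0 = 0.0, acc1 = 0.0;   // register variables: two accumulators
    const double *end = x + (n & ~(size_t)1); // round n down to a multiple of 2
    while (x != end) {                        // pointers instead of x[i] indexing
        acc0 += x[0] * y[0];                  // loop unrolled by 2
        acc1 += x[1] * y[1];
        x += 2;
        y += 2;
    }
    if (n & 1)                                // odd leftover element
        acc0 += x[0] * y[0];
    return acc0 + acc1;
}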
Found where the pack routines live in the starter code.
/*
* pack one subpanel of A
*
* pack like this
* if A is row major order
*
* a c e g
* b d f h
* i k m o
* j l n p
* q r s t
*
* then pack into a sub panel
* each letter represents sequential
* addresses in the packed result
* (e.g. a, b, c, d are sequential
* addresses).
* - down each column
* - then next column in sub panel
* - then next sub panel down (on subsequent call)
* a c e g < each call packs one
* b d f h < subpanel
* -------
* i k m o
* j l n p
* -------
* q r s t
* 0 0 0 0
*/
static inline
void packA_mcxkc_d(
int m,
int k,
double *XA,
int ldXA,
double *packA
)
{
}
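The starter body is empty; below is the minimal sketch I arrived at from the comment above: walk down each column of the MR-tall subpanel, then move to the next column, zero-padding rows past m. The _sketch suffix is mine; this is my reading of the hints (it assumes DGEMM_MR is visible from bl_config.h), not the reference implementation.

static inline
void packA_mcxkc_d_sketch(
        int    m,
        int    k,
        double *XA,
        int    ldXA,
        double *packA
        )
{
    int i, p;
    for ( p = 0; p < k; ++p ) {
        for ( i = 0; i < m; ++i ) {
            // A is row major, so stepping down a column strides by ldXA
            *packA++ = XA[ i * ldXA + p ];
        }
        // zero-pad so the microkernel can always assume an MR x k tile
        for ( ; i < DGEMM_MR; ++i ) {
            *packA++ = 0.0;
        }
    }
}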
/*
* --------------------------------------------------------------------------
*/
/*
* pack one subpanel of B
*
* pack like this
* if B is
*
* row major order matrix
* a b c j k l s t
* d e f m n o u v
* g h i p q r w x
*
* each letter represents sequential
* addresses in the packed result
* (e.g. a, b, c, d are sequential
* addresses).
*
* Then pack
* - across each row in the subpanel
* - then next row in each subpanel
* - then next subpanel (on subsequent call)
*
* a b c | j k l | s t 0
* d e f | m n o | u v 0
* g h i | p q r | w x 0
*
* ^^^^^
* each call packs one subpanel
*/
static inline
void packB_kcxnc_d(
int n,
int k,
double *XB,
int ldXB, // ldXB is the original k
double *packB
)
{
}
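And the matching sketch for B: walk across each row of the NR-wide subpanel, then the next row, zero-padding columns past n. Again my reading of the hints, assuming DGEMM_NR comes from bl_config.h; the _sketch suffix is mine.

static inline
void packB_kcxnc_d_sketch(
        int    n,
        int    k,
        double *XB,
        int    ldXB,
        double *packB
        )
{
    int j, p;
    for ( p = 0; p < k; ++p ) {
        for ( j = 0; j < n; ++j ) {
            // B is row major, so a row is contiguous; stride rows by ldXB
            *packB++ = XB[ p * ldXB + j ];
        }
        // zero-pad so the microkernel can always assume a k x NR tile
        for ( ; j < DGEMM_NR; ++j ) {
            *packB++ = 0.0;
        }
    }
}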
Trying to implement it based on the hints.
Starting to read the source code:
Getting ready to implement the packing part.
What confused me at first is what lda actually is. (It is the leading dimension: the memory stride between consecutive rows of a row-major matrix, which for a submatrix of a bigger array is the full array's row length, not the submatrix's; a small example follows.)
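A tiny self-contained example of what the leading dimension means for a row-major submatrix (all names are mine, purely for illustration):

#include <stdio.h>

// element (i, j) of a row-major matrix whose rows are ld doubles apart;
// ld is the row length of the FULL array, not of the submatrix in view
#define elem(M, i, j, ld) ((M)[(i) * (ld) + (j)])

int main(void)
{
    double big[4 * 6];                        // a 4x6 row-major array
    for (int i = 0; i < 24; i++) big[i] = i;

    // view the 2x3 submatrix whose top-left corner is at row 1, col 2
    double *sub = &elem(big, 1, 2, 6);
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 3; j++)
            printf("%5.1f ", elem(sub, i, j, 6)); // stride is still 6
        printf("\n");
    }
    return 0;
}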
I have to gripe about EC2: why does it just blow up the moment a bug appears? Could it fail a bit more gracefully?
After writing make's log to a file on disk, EC2 stopped blowing up.
But the tests still do not pass; continuing to fix.
Packing is implemented; now working on the AVX2 optimization.
#include "bl_config.h"
#include "bl_dgemm_kernel.h"
#define a(i, j, ld) a[ (i)*(ld) + (j) ]
#define b(i, j, ld) b[ (i)*(ld) + (j) ]
#define c(i, j, ld) c[ (i)*(ld) + (j) ]
//
// C-based microkernel
//
void bl_dgemm_ukr( int k,
int m,
int n,
const double * restrict a,
const double * restrict b,
double *c,
unsigned long long ldc,
aux_t* data )
{
int l, j, i;
for ( l = 0; l < k; ++l )
{
for ( j = 0; j < n; ++j )
{
for ( i = 0; i < m; ++i )
{
// the starter code used ldc here because a[] and b[] were not packed;
// once they are packed, the leading index becomes DGEMM_MR for a[]
// (column major within the MR-tall subpanel) and DGEMM_NR for b[]
// (row major within the NR-wide subpanel), as in the line below
//
// unpacked original: c( i, j, ldc ) += a( i, l, ldc ) * b( l, j, ldc );
c(i, j, ldc) += a(l, i, DGEMM_MR) * b(l, j, DGEMM_NR);
}
}
}
}
// cse260
// you can put your optimized kernels here
// - put the function prototypes in bl_dgemm_kernel.h
// - define BL_MICRO_KERNEL appropriately in bl_config.h
//
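Per the comment above, swapping in a kernel should just be a prototype plus a macro. A sketch of my guess at the wiring (check the actual bl_config.h in the starter code for the exact spelling):

/* bl_dgemm_kernel.h: declare the new kernel */
void bl_dgemm_avx( int k, int m, int n,
                   const double * restrict a, const double * restrict b,
                   double *c, unsigned long long ldc, aux_t *data );

/* bl_config.h: point the dispatch macro at the kernel to benchmark */
#define BL_MICRO_KERNEL bl_dgemm_avx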
void bl_dgemm_avx( int k,
int m,
int n,
const double * restrict a,
const double * restrict b,
double *c,
unsigned long long ldc,
aux_t* data )
{
int l, j, i;
if(m == DGEMM_MR && n == DGEMM_NR){
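        // 4x8 register tile: c0..c3 hold rows 0..3 of C at columns 0..3,
        // c4..c7 hold the same rows at columns 4..7; this layout assumes
        // DGEMM_MR == 4 and DGEMM_NR == 8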
register __m256d c0 = _mm256_loadu_pd(c);
register __m256d c1 = _mm256_loadu_pd(c+ldc);
register __m256d c2 = _mm256_loadu_pd(c+2*ldc);
register __m256d c3 = _mm256_loadu_pd(c+3*ldc);
register __m256d c4 = _mm256_loadu_pd(c+4);
register __m256d c5 = _mm256_loadu_pd(c+ldc+4);
register __m256d c6 = _mm256_loadu_pd(c+2*ldc+4);
register __m256d c7 = _mm256_loadu_pd(c+3*ldc+4);
for(l=0; l<k; l++){
register __m256d b_load = _mm256_load_pd(b+l*n);
register __m256d b_load2 = _mm256_load_pd(b+l*n +4);
register __m256d a0 = _mm256_broadcast_sd ( a + l*m);
register __m256d a1 = _mm256_broadcast_sd ( a + l*m + 1 );
register __m256d a2 = _mm256_broadcast_sd ( a + l*m + 2 );
register __m256d a3 = _mm256_broadcast_sd ( a + l*m + 3 );
c0 = _mm256_fmadd_pd(a0, b_load, c0);
c1 = _mm256_fmadd_pd(a1, b_load, c1);
c2 = _mm256_fmadd_pd(a2, b_load, c2);
c3 = _mm256_fmadd_pd(a3, b_load, c3);
c4 = _mm256_fmadd_pd(a0, b_load2, c4);
c5 = _mm256_fmadd_pd(a1, b_load2, c5);
c6 = _mm256_fmadd_pd(a2, b_load2, c6);
c7 = _mm256_fmadd_pd(a3, b_load2, c7);
// c0 += a0*b_load;
// c1 += a1*b_load;
// c2 += a2*b_load;
// c3 += a3*b_load;
}
_mm256_storeu_pd(c, c0);
_mm256_storeu_pd(c+ldc, c1);
_mm256_storeu_pd(c+2*ldc, c2);
_mm256_storeu_pd(c+3*ldc, c3);
_mm256_storeu_pd(c+4, c4);
_mm256_storeu_pd(c+ldc+4, c5);
_mm256_storeu_pd(c+2*ldc+4, c6);
_mm256_storeu_pd(c+3*ldc+4, c7);
}else{
for ( l = 0; l < k; ++l )
{
for ( j = 0; j < n; ++j )
{
for ( i = 0; i < m; ++i )
{
c(i, j, ldc) += a(l, i, DGEMM_MR) * b(l, j, DGEMM_NR);
}
}
}
}
}
void bl_dgemm_avx_2_unroll_4( int k,
int m,
int n,
const double * restrict a,
const double * restrict b,
double *c,
unsigned long long ldc,
aux_t* data )
{
int l, j, i;
if(m == DGEMM_MR && n == DGEMM_NR){
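        // tile C into 4x4 register blocks; assumes DGEMM_MR and DGEMM_NR
        // are both multiples of 4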
for(i=0; i<m; i+=4){
for(j=0; j<n; j+=4){
register __m256d c0 = _mm256_loadu_pd(c+(0+i)*ldc+j);
register __m256d c1 = _mm256_loadu_pd(c+(1+i)*ldc+j);
register __m256d c2 = _mm256_loadu_pd(c+(2+i)*ldc+j);
register __m256d c3 = _mm256_loadu_pd(c+(3+i)*ldc+j);
for(l=0; l<=k-4; l+=4){
register __m256d b_load = _mm256_load_pd(b+l*n+j);
register __m256d a0 = _mm256_broadcast_sd ( a + l*m + i);
register __m256d a1 = _mm256_broadcast_sd ( a + l*m + 1 + i);
register __m256d a2 = _mm256_broadcast_sd ( a + l*m + 2 + i);
register __m256d a3 = _mm256_broadcast_sd ( a + l*m + 3 + i);
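                    // c += a*b on __m256d values is GCC/Clang vector-extension
                    // syntax; with -mfma (and FP contraction enabled) it can
                    // compile to the same vfmadd as _mm256_fmadd_pd above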
c0 += a0*b_load;
c1 += a1*b_load;
c2 += a2*b_load;
c3 += a3*b_load;
b_load = _mm256_load_pd(b+(l+1)*n+j);
a0 = _mm256_broadcast_sd ( a + (l+1)*m + i);
a1 = _mm256_broadcast_sd ( a + (l+1)*m + 1 + i);
a2 = _mm256_broadcast_sd ( a + (l+1)*m + 2 + i);
a3 = _mm256_broadcast_sd ( a + (l+1)*m + 3 + i);
c0 += a0*b_load;
c1 += a1*b_load;
c2 += a2*b_load;
c3 += a3*b_load;
b_load = _mm256_load_pd(b+(l+2)*n+j);
a0 = _mm256_broadcast_sd ( a + (l+2)*m + i);
a1 = _mm256_broadcast_sd ( a + (l+2)*m + 1 + i);
a2 = _mm256_broadcast_sd ( a + (l+2)*m + 2 + i);
a3 = _mm256_broadcast_sd ( a + (l+2)*m + 3 + i);
c0 += a0*b_load;
c1 += a1*b_load;
c2 += a2*b_load;
c3 += a3*b_load;
b_load = _mm256_load_pd(b+(l+3)*n+j);
a0 = _mm256_broadcast_sd ( a + (l+3)*m + i);
a1 = _mm256_broadcast_sd ( a + (l+3)*m + 1 + i);
a2 = _mm256_broadcast_sd ( a + (l+3)*m + 2 + i);
a3 = _mm256_broadcast_sd ( a + (l+3)*m + 3 + i);
c0 += a0*b_load;
c1 += a1*b_load;
c2 += a2*b_load;
c3 += a3*b_load;
}
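                // remainder loop: handles the k % 4 leftover iterations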
for(; l<k; l++){
register __m256d b_load = _mm256_load_pd(b+l*n+j);
register __m256d a0 = _mm256_broadcast_sd ( a + l*m + i);
register __m256d a1 = _mm256_broadcast_sd ( a + l*m + 1 + i);
register __m256d a2 = _mm256_broadcast_sd ( a + l*m + 2 + i);
register __m256d a3 = _mm256_broadcast_sd ( a + l*m + 3 + i);
c0 += a0*b_load;
c1 += a1*b_load;
c2 += a2*b_load;
c3 += a3*b_load;
}
_mm256_storeu_pd(c+(0+i)*ldc+j, c0);
_mm256_storeu_pd(c+(1+i)*ldc+j, c1);
_mm256_storeu_pd(c+(2+i)*ldc+j, c2);
_mm256_storeu_pd(c+(3+i)*ldc+j, c3);
}
}
}else{
for ( l = 0; l < k; ++l )
{
for ( j = 0; j < n; ++j )
{
for ( i = 0; i < m; ++i )
{
c(i, j, ldc) += a(l, i, DGEMM_MR) * b(l, j, DGEMM_NR);
}
}
}
}
}
void bl_dgemm_avx_2_unroll_8( int k,
int m,
int n,
const double * restrict a,
const double * restrict b,
double *c,
unsigned long long ldc,
aux_t* data )
{
int l, j, i;
if(m == DGEMM_MR && n == DGEMM_NR){
for(i=0; i<m; i+=4){
for(j=0; j<n; j+=4){
register __m256d c0 = _mm256_loadu_pd(c+(0+i)*ldc+j);
register __m256d c1 = _mm256_loadu_pd(c+(1+i)*ldc+j);
register __m256d c2 = _mm256_loadu_pd(c+(2+i)*ldc+j);
register __m256d c3 = _mm256_loadu_pd(c+(3+i)*ldc+j);
for(l=0; l<=k-8; l+=8){
register __m256d b_load = _mm256_load_pd(b+l*n+j);
register __m256d a0 = _mm256_broadcast_sd ( a + l*m + i);
register __m256d a1 = _mm256_broadcast_sd ( a + l*m + 1 + i);
register __m256d a2 = _mm256_broadcast_sd ( a + l*m + 2 + i);
register __m256d a3 = _mm256_broadcast_sd ( a + l*m + 3 + i);
c0 += a0*b_load;
c1 += a1*b_load;
c2 += a2*b_load;
c3 += a3*b_load;
// register __m256d b_load_2 = _mm256_load_pd(b+(l+1)*n+j);
// register __m256d a0_2 = _mm256_broadcast_sd ( a + (l+1)*m + i);
// register __m256d a1_2 = _mm256_broadcast_sd ( a + (l+1)*m + 1 + i);
// register __m256d a2_2 = _mm256_broadcast_sd ( a + (l+1)*m + 2 + i);
// register __m256d a3_2 = _mm256_broadcast_sd ( a + (l+1)*m + 3 + i);
// c0 += a0_2*b_load_2;
// c1 += a1_2*b_load_2;
// c2 += a2_2*b_load_2;
// c3 += a3_2*b_load_2;
b_load = _mm256_load_pd(b+(l+1)*n+j);
a0 = _mm256_broadcast_sd ( a + (l+1)*m + i);
a1 = _mm256_broadcast_sd ( a + (l+1)*m + 1 + i);
a2 = _mm256_broadcast_sd ( a + (l+1)*m + 2 + i);
a3 = _mm256_broadcast_sd ( a + (l+1)*m + 3 + i);
c0 += a0*b_load;
c1 += a1*b_load;
c2 += a2*b_load;
c3 += a3*b_load;
b_load = _mm256_load_pd(b+(l+2)*n+j);
a0 = _mm256_broadcast_sd ( a + (l+2)*m + i);
a1 = _mm256_broadcast_sd ( a + (l+2)*m + 1 + i);
a2 = _mm256_broadcast_sd ( a + (l+2)*m + 2 + i);
a3 = _mm256_broadcast_sd ( a + (l+2)*m + 3 + i);
c0 += a0*b_load;
c1 += a1*b_load;
c2 += a2*b_load;
c3 += a3*b_load;
b_load = _mm256_load_pd(b+(l+3)*n+j);
a0 = _mm256_broadcast_sd ( a + (l+3)*m + i);
a1 = _mm256_broadcast_sd ( a + (l+3)*m + 1 + i);
a2 = _mm256_broadcast_sd ( a + (l+3)*m + 2 + i);
a3 = _mm256_broadcast_sd ( a + (l+3)*m + 3 + i);
c0 += a0*b_load;
c1 += a1*b_load;
c2 += a2*b_load;
c3 += a3*b_load;
b_load = _mm256_load_pd(b+(l+4)*n+j);
a0 = _mm256_broadcast_sd ( a + (l+4)*m + i);
a1 = _mm256_broadcast_sd ( a + (l+4)*m + 1 + i);
a2 = _mm256_broadcast_sd ( a + (l+4)*m + 2 + i);
a3 = _mm256_broadcast_sd ( a + (l+4)*m + 3 + i);
c0 += a0*b_load;
c1 += a1*b_load;
c2 += a2*b_load;
c3 += a3*b_load;
b_load = _mm256_load_pd(b+(l+5)*n+j);
a0 = _mm256_broadcast_sd ( a + (l+5)*m + i);
a1 = _mm256_broadcast_sd ( a + (l+5)*m + 1 + i);
a2 = _mm256_broadcast_sd ( a + (l+5)*m + 2 + i);
a3 = _mm256_broadcast_sd ( a + (l+5)*m + 3 + i);
c0 += a0*b_load;
c1 += a1*b_load;
c2 += a2*b_load;
c3 += a3*b_load;
b_load = _mm256_load_pd(b+(l+6)*n+j);
a0 = _mm256_broadcast_sd ( a + (l+6)*m + i);
a1 = _mm256_broadcast_sd ( a + (l+6)*m + 1 + i);
a2 = _mm256_broadcast_sd ( a + (l+6)*m + 2 + i);
a3 = _mm256_broadcast_sd ( a + (l+6)*m + 3 + i);
c0 += a0*b_load;
c1 += a1*b_load;
c2 += a2*b_load;
c3 += a3*b_load;
b_load = _mm256_load_pd(b+(l+7)*n+j);
a0 = _mm256_broadcast_sd ( a + (l+7)*m + i);
a1 = _mm256_broadcast_sd ( a + (l+7)*m + 1 + i);
a2 = _mm256_broadcast_sd ( a + (l+7)*m + 2 + i);
a3 = _mm256_broadcast_sd ( a + (l+7)*m + 3 + i);
c0 += a0*b_load;
c1 += a1*b_load;
c2 += a2*b_load;
c3 += a3*b_load;
}
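                // remainder loop: handles the k % 8 leftover iterations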
for(; l<k; l++){
register __m256d b_load = _mm256_load_pd(b+l*n+j);
register __m256d a0 = _mm256_broadcast_sd ( a + l*m + i);
register __m256d a1 = _mm256_broadcast_sd ( a + l*m + 1 + i);
register __m256d a2 = _mm256_broadcast_sd ( a + l*m + 2 + i);
register __m256d a3 = _mm256_broadcast_sd ( a + l*m + 3 + i);
c0 += a0*b_load;
c1 += a1*b_load;
c2 += a2*b_load;
c3 += a3*b_load;
}
_mm256_storeu_pd(c+(0+i)*ldc+j, c0);
_mm256_storeu_pd(c+(1+i)*ldc+j, c1);
_mm256_storeu_pd(c+(2+i)*ldc+j, c2);
_mm256_storeu_pd(c+(3+i)*ldc+j, c3);
}
}
}else{
for ( l = 0; l < k; ++l )
{
for ( j = 0; j < n; ++j )
{
for ( i = 0; i < m; ++i )
{
c(i, j, ldc) += a(l, i, DGEMM_MR) * b(l, j, DGEMM_NR);
}
}
}
}
}
void bl_dgemm_avx_3( int k,
int m,
int n,
const double * restrict a,
const double * restrict b,
double *c,
unsigned long long ldc,
aux_t* data )
{
int l, j, i;
if(m == DGEMM_MR && n == DGEMM_NR){
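        // widen the row block to 8: c0..c7 hold eight consecutive C rows at
        // columns j..j+3, so this variant assumes DGEMM_MR is a multiple of 8
        // and DGEMM_NR a multiple of 4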
for(i=0; i<m; i+=8){
for(j=0; j<n; j+=4){
register __m256d c0 = _mm256_loadu_pd(c+(0+i)*ldc+j);
register __m256d c1 = _mm256_loadu_pd(c+(1+i)*ldc+j);
register __m256d c2 = _mm256_loadu_pd(c+(2+i)*ldc+j);
register __m256d c3 = _mm256_loadu_pd(c+(3+i)*ldc+j);
register __m256d c4 = _mm256_loadu_pd(c+(4+i)*ldc+j);
register __m256d c5 = _mm256_loadu_pd(c+(5+i)*ldc+j);
register __m256d c6 = _mm256_loadu_pd(c+(6+i)*ldc+j);
register __m256d c7 = _mm256_loadu_pd(c+(7+i)*ldc+j);
for(l=0; l<k; l++){
register __m256d b_load = _mm256_load_pd(b+l*n+j);
register __m256d a0 = _mm256_broadcast_sd ( a + l*m + i);
register __m256d a1 = _mm256_broadcast_sd ( a + l*m + 1 + i);
register __m256d a2 = _mm256_broadcast_sd ( a + l*m + 2 + i);
register __m256d a3 = _mm256_broadcast_sd ( a + l*m + 3 + i);
register __m256d a4 = _mm256_broadcast_sd ( a + l*m + 4 + i);
register __m256d a5 = _mm256_broadcast_sd ( a + l*m + 5 + i);
register __m256d a6 = _mm256_broadcast_sd ( a + l*m + 6 + i);
register __m256d a7 = _mm256_broadcast_sd ( a + l*m + 7 + i);
// c0 = _mm256_fmadd_pd(a0, b_load, c0);
// c1 = _mm256_fmadd_pd(a1, b_load, c1);
// c2 = _mm256_fmadd_pd(a2, b_load, c2);
// c3 = _mm256_fmadd_pd(a3, b_load, c3);
c0 += a0*b_load;
c1 += a1*b_load;
c2 += a2*b_load;
c3 += a3*b_load;
c4 += a4*b_load;
c5 += a5*b_load;
c6 += a6*b_load;
c7 += a7*b_load;
}
_mm256_storeu_pd(c+(0+i)*ldc+j, c0);
_mm256_storeu_pd(c+(1+i)*ldc+j, c1);
_mm256_storeu_pd(c+(2+i)*ldc+j, c2);
_mm256_storeu_pd(c+(3+i)*ldc+j, c3);
_mm256_storeu_pd(c+(4+i)*ldc+j, c4);
_mm256_storeu_pd(c+(5+i)*ldc+j, c5);
_mm256_storeu_pd(c+(6+i)*ldc+j, c6);
_mm256_storeu_pd(c+(7+i)*ldc+j, c7);
}
}
}else{
for ( l = 0; l < k; ++l )
{
for ( j = 0; j < n; ++j )
{
for ( i = 0; i < m; ++i )
{
c(i, j, ldc) += a(l, i, DGEMM_MR) * b(l, j, DGEMM_NR);
}
}
}
}
}