Study Log | CSE 260 Parallel Computing

Why is matrix-matrix multiplication abbreviated GEMM? (Because in BLAS naming it is the GEneral Matrix-Matrix multiply; the leading D in DGEMM marks double precision.)

The three basic optimization techniques (all three are combined in the small sketch after this list):

  1. Using pointers

  2. Loop unrolling

  3. Register variables

Sounds so familiar, as if I were back studying CSAPP.
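As a refresher, here is a minimal dot product of my own (a sketch, not assignment code) that uses all three techniques at once:

double dot(const double *x, const double *y, int n)
{
    register double s0 = 0.0, s1 = 0.0;   /* 3. register variables */
    const double *px = x, *py = y;        /* 1. pointers instead of x[i] */
    int i;
    for (i = 0; i + 2 <= n; i += 2) {     /* 2. two-way loop unrolling */
        s0 += px[0] * py[0];
        s1 += px[1] * py[1];
        px += 2;
        py += 2;
    }
    for (; i < n; ++i)                    /* cleanup when n is odd */
        s0 += *px++ * *py++;
    return s0 + s1;
}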

Poorman’s BLAS

It can be used to implement the level-3 BLAS.

OK, so it is basically blocking.

The idea: partition the matrices.
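Sketching the blocked loop nest for myself (BS is a tunable I picked, not a value from the course):

#define BS 64

/* C += A * B on n x n row-major matrices; the three outer loops walk
   BS x BS blocks so each block pair is reused while cache-resident */
void dgemm_blocked(int n, const double *A, const double *B, double *C)
{
    for (int ii = 0; ii < n; ii += BS)
        for (int kk = 0; kk < n; kk += BS)
            for (int jj = 0; jj < n; jj += BS)
                for (int i = ii; i < ii + BS && i < n; ++i)
                    for (int l = kk; l < kk + BS && l < n; ++l) {
                        double a_il = A[i * n + l];
                        for (int j = jj; j < jj + BS && j < n; ++j)
                            C[i * n + j] += a_il * B[l * n + j];
                    }
}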

Found where the pack routines live in the starter code.

/* 
 * pack one subpanel of A
 *
 * pack like this 
 * if A is row major order
 *
 *     a c e g
 *     b d f h
 *     i k m o
 *     j l n p
 *     q r s t
 *     
 * then pack into a sub panel
 * each letter represents sequential
 * addresses in the packed result
 * (e.g. a, b, c, d are sequential
 * addresses). 
 * - down each column
 * - then next column in sub panel
 * - then next sub panel down (on subsequent call)
 
 *     a c e g  < each call packs one
 *     b d f h  < subpanel
 *     -------
 *     i k m o
 *     j l n p
 *     -------
 *     q r s t
 *     0 0 0 0
 */
static inline
void packA_mcxkc_d(
        int    m,
        int    k,
        double *XA,
        int    ldXA,
        double *packA
        )
{
  
}
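My packing attempt for A, reconstructed here as a sketch (the _sketch suffix is mine, to avoid pretending this is the graded code): it assumes the packed subpanel keeps a fixed column stride of DGEMM_MR (from bl_config.h) with zero padding, which is what the microkernel access a(l, i, DGEMM_MR) expects.

static inline
void packA_mcxkc_d_sketch(
        int    m,            // rows in this subpanel (<= DGEMM_MR)
        int    k,            // columns to pack
        double *XA,          // top-left of the subpanel in row-major A
        int    ldXA,         // row stride of the original A
        double *packA
        )
{
  int i, j;
  for ( j = 0; j < k; ++j ) {            // then next column in sub panel
    for ( i = 0; i < m; ++i )            // down each column
      packA[ j * DGEMM_MR + i ] = XA[ i * ldXA + j ];
    for ( ; i < DGEMM_MR; ++i )          // zero-pad a short last subpanel
      packA[ j * DGEMM_MR + i ] = 0.0;
  }
}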



/*
 * --------------------------------------------------------------------------
 */

/* 
 * pack one subpanel of B
 * 
 * pack like this 
 * if B is 
 *
 * row major order matrix
 *     a b c j k l s t
 *     d e f m n o u v
 *     g h i p q r w x
 *
 * each letter represents sequential
 * addresses in the packed result
 * (e.g. a, b, c, d are sequential
 * addresses). 
 *
 * Then pack 
 *   - across each row in the subpanel
 *   - then next row in each subpanel
 *   - then next subpanel (on subsequent call)
 *
 *     a b c |  j k l |  s t 0
 *     d e f |  m n o |  u v 0
 *     g h i |  p q r |  w x 0
 *
 *     ^^^^^
 *     each call packs one subpanel
 */
static inline
void packB_kcxnc_d(
        int    n,
        int    k,
        double *XB,
        int    ldXB, // ldXB is the original k
        double *packB
        )
{
}
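And the matching sketch for B: across each row, then the next row, with a fixed row stride of DGEMM_NR and zero padding, matching b(l, j, DGEMM_NR) in the microkernel. (The starter comment says ldXB is "the original k", so the real layout may differ from my assumption; the indexing below assumes XB is row-major with row stride ldXB.)

static inline
void packB_kcxnc_d_sketch(
        int    n,            // columns in this subpanel (<= DGEMM_NR)
        int    k,            // rows to pack
        double *XB,          // top-left of the subpanel in row-major B
        int    ldXB,         // row stride of the original B
        double *packB
        )
{
  int i, j;
  for ( i = 0; i < k; ++i ) {            // then next row
    for ( j = 0; j < n; ++j )            // across each row
      packB[ i * DGEMM_NR + j ] = XB[ i * ldXB + j ];
    for ( ; j < DGEMM_NR; ++j )          // zero-pad a short last subpanel
      packB[ i * DGEMM_NR + j ] = 0.0;
  }
}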

Trying to implement it following the hints.

Starting to read the source code:

Getting ready to implement the packing part of the code.
What confuses me right now is what on earth lda actually is. (Once it clicked: it is the "leading dimension", the memory stride between consecutive rows of a row-major matrix, which is what lets a kernel address a submatrix living inside a larger array; see the little demo below.)
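The tiny standalone demo I wrote to convince myself (not assignment code):

#include <stdio.h>

/* element (i, j) of a submatrix starting at row r0, column c0 of a
   larger row-major matrix whose rows are ld doubles apart in memory */
static double sub_elem(const double *A, int ld, int r0, int c0, int i, int j)
{
    return A[(r0 + i) * ld + (c0 + j)];
}

int main(void)
{
    double A[3][4] = { { 1,  2,  3,  4 },
                       { 5,  6,  7,  8 },
                       { 9, 10, 11, 12 } };
    /* the 2x2 submatrix at (1,1): expect 6 7 / 10 11 */
    printf("%g %g\n", sub_elem(&A[0][0], 4, 1, 1, 0, 0),
                      sub_elem(&A[0][0], 4, 1, 1, 0, 1));
    printf("%g %g\n", sub_elem(&A[0][0], 4, 1, 1, 1, 0),
                      sub_elem(&A[0][0], 4, 1, 1, 1, 1));
    return 0;
}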

Have to gripe about EC2 for a moment: why does a single bug blow the whole instance up? Could it fail a little more gracefully?

After dumping the make log to a file on disk, EC2 stopped blowing up.

But the tests still fail; continuing to fix.

Packing is implemented; now working on the AVX2 optimization.

#include "bl_config.h"
#include "bl_dgemm_kernel.h"

#define a(i, j, ld) a[ (i)*(ld) + (j) ]
#define b(i, j, ld) b[ (i)*(ld) + (j) ]
#define c(i, j, ld) c[ (i)*(ld) + (j) ]

//
// C-based microkernel
//
void bl_dgemm_ukr( int    k,
                   int    m,
                   int    n,
                   const double * restrict a,
                   const double * restrict b,
                   double *c,
                   unsigned long long ldc,
                   aux_t* data )
{
    int l, j, i;

    for ( l = 0; l < k; ++l )
    {                 
        for ( j = 0; j < n; ++j )
        { 
            for ( i = 0; i < m; ++i )
            { 
              // a[] and b[] are packed now, so their leading dimensions are
              // the register block sizes DGEMM_MR and DGEMM_NR; only c[]
              // still uses the original ldc. The starter code's unpacked
              // version read:
              //   c( i, j, ldc ) += a( i, l, ldc ) * b( l, j, ldc );
            c(i, j, ldc) += a(l, i, DGEMM_MR) * b(l, j, DGEMM_NR);
            }
        }
    }

}


// cse260
// you can put your optimized kernels here
// - put the function prototypes in bl_dgemm_kernel.h
// - define BL_MICRO_KERNEL appropriately in bl_config.h
//
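Per that comment, swapping in a kernel should just be a one-liner in bl_config.h (my guess at the exact form, not verified against the real file):

// #define BL_MICRO_KERNEL bl_dgemm_avx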

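// 4 (rows of C) x 8 (columns) register block: the C tile lives in 8 ymm
// accumulators; each k step broadcasts four packed-A values and FMAs them
// against one 8-wide packed row of B, held as two 4-wide halves.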
void bl_dgemm_avx( int    k,
                   int    m,
                   int    n,
                   const double * restrict a,
                   const double * restrict b,
                   double *c,
                   unsigned long long ldc,
                   aux_t* data )
{
    int l, j, i;
    if(m == DGEMM_MR && n == DGEMM_NR){
        register __m256d c0 = _mm256_loadu_pd(c);
        register __m256d c1 = _mm256_loadu_pd(c+ldc);
        register __m256d c2 = _mm256_loadu_pd(c+2*ldc);
        register __m256d c3 = _mm256_loadu_pd(c+3*ldc);
        register __m256d c4 = _mm256_loadu_pd(c+4);
        register __m256d c5 = _mm256_loadu_pd(c+ldc+4);
        register __m256d c6 = _mm256_loadu_pd(c+2*ldc+4);
        register __m256d c7 = _mm256_loadu_pd(c+3*ldc+4);
        for(l=0; l<k; l++){
            register __m256d b_load = _mm256_load_pd(b+l*n);
            register __m256d b_load2 = _mm256_load_pd(b+l*n +4);
            register __m256d a0 = _mm256_broadcast_sd ( a + l*m);
            register __m256d a1 = _mm256_broadcast_sd ( a + l*m + 1 );
            register __m256d a2 = _mm256_broadcast_sd ( a + l*m + 2 );
            register __m256d a3 = _mm256_broadcast_sd ( a + l*m + 3 );

            c0 = _mm256_fmadd_pd(a0, b_load, c0);
            c1 = _mm256_fmadd_pd(a1, b_load, c1);
            c2 = _mm256_fmadd_pd(a2, b_load, c2);
            c3 = _mm256_fmadd_pd(a3, b_load, c3);
            c4 = _mm256_fmadd_pd(a0, b_load2, c4);
            c5 = _mm256_fmadd_pd(a1, b_load2, c5);
            c6 = _mm256_fmadd_pd(a2, b_load2, c6);
            c7 = _mm256_fmadd_pd(a3, b_load2, c7);
            // the operator form below is equivalent (GCC/Clang vector
            // extensions on __m256d; with -mfma the compiler usually
            // contracts a*b + c into vfmadd anyway):
            // c0 += a0*b_load;
            // c1 += a1*b_load;
            // c2 += a2*b_load;
            // c3 += a3*b_load;
        }
        _mm256_storeu_pd(c, c0);
        _mm256_storeu_pd(c+ldc, c1);
        _mm256_storeu_pd(c+2*ldc, c2);
        _mm256_storeu_pd(c+3*ldc, c3);
        _mm256_storeu_pd(c+4, c4);
        _mm256_storeu_pd(c+ldc+4, c5);
        _mm256_storeu_pd(c+2*ldc+4, c6);
        _mm256_storeu_pd(c+3*ldc+4, c7);
    }else{
        for ( l = 0; l < k; ++l )
        {                 
            for ( j = 0; j < n; ++j )
            { 
                for ( i = 0; i < m; ++i )
                { 
                    c(i, j, ldc) += a(l, i, DGEMM_MR) * b(l, j, DGEMM_NR);
                }
            }
        }
    }

}
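Design note to self: this 4x8 register block keeps 8 accumulators for C, 4 broadcast values from a, and 2 vectors from b live at once, 14 of the 16 ymm registers AVX2 offers, so it is about as large as the register tile can get without spilling.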

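// 4x4 register-blocked variant with the k-loop hand-unrolled 4x plus a
// cleanup loop for leftover k; uses GCC/Clang vector-extension arithmetic
// on __m256d instead of explicit FMA intrinsics.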
void bl_dgemm_avx_2_unroll_4( int    k,
                   int    m,
                   int    n,
                   const double * restrict a,
                   const double * restrict b,
                   double *c,
                   unsigned long long ldc,
                   aux_t* data )
{
    int l, j, i;
    if(m == DGEMM_MR && n == DGEMM_NR){
        for(i=0; i<m; i+=4){
            for(j=0; j<n; j+=4){
                register __m256d c0 = _mm256_loadu_pd(c+(0+i)*ldc+j);
                register __m256d c1 = _mm256_loadu_pd(c+(1+i)*ldc+j);
                register __m256d c2 = _mm256_loadu_pd(c+(2+i)*ldc+j);
                register __m256d c3 = _mm256_loadu_pd(c+(3+i)*ldc+j);
                for(l=0; l<=k-4; l+=4){
                    register __m256d b_load = _mm256_load_pd(b+l*n+j);

                    register __m256d a0 = _mm256_broadcast_sd ( a + l*m + i);
                    register __m256d a1 = _mm256_broadcast_sd ( a + l*m + 1 + i);
                    register __m256d a2 = _mm256_broadcast_sd ( a + l*m + 2 + i);
                    register __m256d a3 = _mm256_broadcast_sd ( a + l*m + 3 + i);

                    c0 += a0*b_load;
                    c1 += a1*b_load;
                    c2 += a2*b_load;
                    c3 += a3*b_load;

                    b_load = _mm256_load_pd(b+(l+1)*n+j);
                    
                    a0 = _mm256_broadcast_sd ( a + (l+1)*m + i);
                    a1 = _mm256_broadcast_sd ( a + (l+1)*m + 1 + i);
                    a2 = _mm256_broadcast_sd ( a + (l+1)*m + 2 + i);
                    a3 = _mm256_broadcast_sd ( a + (l+1)*m + 3 + i);

                    c0 += a0*b_load;
                    c1 += a1*b_load;
                    c2 += a2*b_load;
                    c3 += a3*b_load;

                    b_load = _mm256_load_pd(b+(l+2)*n+j);
                    
                    a0 = _mm256_broadcast_sd ( a + (l+2)*m + i);
                    a1 = _mm256_broadcast_sd ( a + (l+2)*m + 1 + i);
                    a2 = _mm256_broadcast_sd ( a + (l+2)*m + 2 + i);
                    a3 = _mm256_broadcast_sd ( a + (l+2)*m + 3 + i);

                    c0 += a0*b_load;
                    c1 += a1*b_load;
                    c2 += a2*b_load;
                    c3 += a3*b_load;

                    b_load = _mm256_load_pd(b+(l+3)*n+j);
                    
                    a0 = _mm256_broadcast_sd ( a + (l+3)*m + i);
                    a1 = _mm256_broadcast_sd ( a + (l+3)*m + 1 + i);
                    a2 = _mm256_broadcast_sd ( a + (l+3)*m + 2 + i);
                    a3 = _mm256_broadcast_sd ( a + (l+3)*m + 3 + i);

                    c0 += a0*b_load;
                    c1 += a1*b_load;
                    c2 += a2*b_load;
                    c3 += a3*b_load;
                }
                for(; l<k; l++){
                    register __m256d b_load = _mm256_load_pd(b+l*n+j);

                    register __m256d a0 = _mm256_broadcast_sd ( a + l*m + i);
                    register __m256d a1 = _mm256_broadcast_sd ( a + l*m + 1 + i);
                    register __m256d a2 = _mm256_broadcast_sd ( a + l*m + 2 + i);
                    register __m256d a3 = _mm256_broadcast_sd ( a + l*m + 3 + i);

                    c0 += a0*b_load;
                    c1 += a1*b_load;
                    c2 += a2*b_load;
                    c3 += a3*b_load;
                }
                _mm256_storeu_pd(c+(0+i)*ldc+j, c0);
                _mm256_storeu_pd(c+(1+i)*ldc+j, c1);
                _mm256_storeu_pd(c+(2+i)*ldc+j, c2);
                _mm256_storeu_pd(c+(3+i)*ldc+j, c3);
            }
        }
    }else{
        for ( l = 0; l < k; ++l )
        {                 
            for ( j = 0; j < n; ++j )
            { 
                for ( i = 0; i < m; ++i )
                { 
                    c(i, j, ldc) += a(l, i, DGEMM_MR) * b(l, j, DGEMM_NR);
                }
            }
        }
    }

}

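// Same 4x4 register block, k-loop unrolled 8x; the commented-out block
// inside is a leftover experiment with independent registers.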
void bl_dgemm_avx_2_unroll_8( int    k,
                   int    m,
                   int    n,
                   const double * restrict a,
                   const double * restrict b,
                   double *c,
                   unsigned long long ldc,
                   aux_t* data )
{
    int l, j, i;
    if(m == DGEMM_MR && n == DGEMM_NR){
        for(i=0; i<m; i+=4){
            for(j=0; j<n; j+=4){
                register __m256d c0 = _mm256_loadu_pd(c+(0+i)*ldc+j);
                register __m256d c1 = _mm256_loadu_pd(c+(1+i)*ldc+j);
                register __m256d c2 = _mm256_loadu_pd(c+(2+i)*ldc+j);
                register __m256d c3 = _mm256_loadu_pd(c+(3+i)*ldc+j);
                for(l=0; l<=k-8; l+=8){
                    register __m256d b_load = _mm256_load_pd(b+l*n+j);

                    register __m256d a0 = _mm256_broadcast_sd ( a + l*m + i);
                    register __m256d a1 = _mm256_broadcast_sd ( a + l*m + 1 + i);
                    register __m256d a2 = _mm256_broadcast_sd ( a + l*m + 2 + i);
                    register __m256d a3 = _mm256_broadcast_sd ( a + l*m + 3 + i);

                    c0 += a0*b_load;
                    c1 += a1*b_load;
                    c2 += a2*b_load;
                    c3 += a3*b_load;

                    // register __m256d b_load_2 = _mm256_load_pd(b+(l+1)*n+j);
                    
                    // register __m256d a0_2 = _mm256_broadcast_sd ( a + (l+1)*m + i);
                    // register __m256d a1_2 = _mm256_broadcast_sd ( a + (l+1)*m + 1 + i);
                    // register __m256d a2_2 = _mm256_broadcast_sd ( a + (l+1)*m + 2 + i);
                    // register __m256d a3_2 = _mm256_broadcast_sd ( a + (l+1)*m + 3 + i);

                    // c0 += a0_2*b_load_2;
                    // c1 += a1_2*b_load_2;
                    // c2 += a2_2*b_load_2;
                    // c3 += a3_2*b_load_2;

                    b_load = _mm256_load_pd(b+(l+1)*n+j);
                    
                    a0 = _mm256_broadcast_sd ( a + (l+1)*m + i);
                    a1 = _mm256_broadcast_sd ( a + (l+1)*m + 1 + i);
                    a2 = _mm256_broadcast_sd ( a + (l+1)*m + 2 + i);
                    a3 = _mm256_broadcast_sd ( a + (l+1)*m + 3 + i);

                    c0 += a0*b_load;
                    c1 += a1*b_load;
                    c2 += a2*b_load;
                    c3 += a3*b_load;

                    b_load = _mm256_load_pd(b+(l+2)*n+j);
                    
                    a0 = _mm256_broadcast_sd ( a + (l+2)*m + i);
                    a1 = _mm256_broadcast_sd ( a + (l+2)*m + 1 + i);
                    a2 = _mm256_broadcast_sd ( a + (l+2)*m + 2 + i);
                    a3 = _mm256_broadcast_sd ( a + (l+2)*m + 3 + i);

                    c0 += a0*b_load;
                    c1 += a1*b_load;
                    c2 += a2*b_load;
                    c3 += a3*b_load;

                    b_load = _mm256_load_pd(b+(l+3)*n+j);
                    
                    a0 = _mm256_broadcast_sd ( a + (l+3)*m + i);
                    a1 = _mm256_broadcast_sd ( a + (l+3)*m + 1 + i);
                    a2 = _mm256_broadcast_sd ( a + (l+3)*m + 2 + i);
                    a3 = _mm256_broadcast_sd ( a + (l+3)*m + 3 + i);

                    c0 += a0*b_load;
                    c1 += a1*b_load;
                    c2 += a2*b_load;
                    c3 += a3*b_load;

                    b_load = _mm256_load_pd(b+(l+4)*n+j);
                    
                    a0 = _mm256_broadcast_sd ( a + (l+4)*m + i);
                    a1 = _mm256_broadcast_sd ( a + (l+4)*m + 1 + i);
                    a2 = _mm256_broadcast_sd ( a + (l+4)*m + 2 + i);
                    a3 = _mm256_broadcast_sd ( a + (l+4)*m + 3 + i);

                    c0 += a0*b_load;
                    c1 += a1*b_load;
                    c2 += a2*b_load;
                    c3 += a3*b_load;

                    b_load = _mm256_load_pd(b+(l+5)*n+j);
                    
                    a0 = _mm256_broadcast_sd ( a + (l+5)*m + i);
                    a1 = _mm256_broadcast_sd ( a + (l+5)*m + 1 + i);
                    a2 = _mm256_broadcast_sd ( a + (l+5)*m + 2 + i);
                    a3 = _mm256_broadcast_sd ( a + (l+5)*m + 3 + i);

                    c0 += a0*b_load;
                    c1 += a1*b_load;
                    c2 += a2*b_load;
                    c3 += a3*b_load;

                    b_load = _mm256_load_pd(b+(l+6)*n+j);
                    
                    a0 = _mm256_broadcast_sd ( a + (l+6)*m + i);
                    a1 = _mm256_broadcast_sd ( a + (l+6)*m + 1 + i);
                    a2 = _mm256_broadcast_sd ( a + (l+6)*m + 2 + i);
                    a3 = _mm256_broadcast_sd ( a + (l+6)*m + 3 + i);

                    c0 += a0*b_load;
                    c1 += a1*b_load;
                    c2 += a2*b_load;
                    c3 += a3*b_load;

                    b_load = _mm256_load_pd(b+(l+7)*n+j);
                    
                    a0 = _mm256_broadcast_sd ( a + (l+7)*m + i);
                    a1 = _mm256_broadcast_sd ( a + (l+7)*m + 1 + i);
                    a2 = _mm256_broadcast_sd ( a + (l+7)*m + 2 + i);
                    a3 = _mm256_broadcast_sd ( a + (l+7)*m + 3 + i);

                    c0 += a0*b_load;
                    c1 += a1*b_load;
                    c2 += a2*b_load;
                    c3 += a3*b_load;
                }
                for(; l<k; l++){
                    register __m256d b_load = _mm256_load_pd(b+l*n+j);

                    register __m256d a0 = _mm256_broadcast_sd ( a + l*m + i);
                    register __m256d a1 = _mm256_broadcast_sd ( a + l*m + 1 + i);
                    register __m256d a2 = _mm256_broadcast_sd ( a + l*m + 2 + i);
                    register __m256d a3 = _mm256_broadcast_sd ( a + l*m + 3 + i);

                    c0 += a0*b_load;
                    c1 += a1*b_load;
                    c2 += a2*b_load;
                    c3 += a3*b_load;
                }
                _mm256_storeu_pd(c+(0+i)*ldc+j, c0);
                _mm256_storeu_pd(c+(1+i)*ldc+j, c1);
                _mm256_storeu_pd(c+(2+i)*ldc+j, c2);
                _mm256_storeu_pd(c+(3+i)*ldc+j, c3);
            }
        }
    }else{
        for ( l = 0; l < k; ++l )
        {                 
            for ( j = 0; j < n; ++j )
            { 
                for ( i = 0; i < m; ++i )
                { 
                    c(i, j, ldc) += a(l, i, DGEMM_MR) * b(l, j, DGEMM_NR);
                }
            }
        }
    }

}

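// 8x4 register block: eight ymm accumulators, one per row of C, against a
// single 4-wide column strip of B.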
void bl_dgemm_avx_3( int    k,
                   int    m,
                   int    n,
                   const double * restrict a,
                   const double * restrict b,
                   double *c,
                   unsigned long long ldc,
                   aux_t* data )
{
    int l, j, i;
    if(m == DGEMM_MR && n == DGEMM_NR){
        for(i=0; i<m; i+=8){
            for(j=0; j<n; j+=4){
                register __m256d c0 = _mm256_loadu_pd(c+(0+i)*ldc+j);
                register __m256d c1 = _mm256_loadu_pd(c+(1+i)*ldc+j);
                register __m256d c2 = _mm256_loadu_pd(c+(2+i)*ldc+j);
                register __m256d c3 = _mm256_loadu_pd(c+(3+i)*ldc+j);
                register __m256d c4 = _mm256_loadu_pd(c+(4+i)*ldc+j);
                register __m256d c5 = _mm256_loadu_pd(c+(5+i)*ldc+j);
                register __m256d c6 = _mm256_loadu_pd(c+(6+i)*ldc+j);
                register __m256d c7 = _mm256_loadu_pd(c+(7+i)*ldc+j);
                for(l=0; l<k; l++){
                    register __m256d b_load = _mm256_load_pd(b+l*n+j);
                    register __m256d a0 = _mm256_broadcast_sd ( a + l*m + i);
                    register __m256d a1 = _mm256_broadcast_sd ( a + l*m + 1 + i);
                    register __m256d a2 = _mm256_broadcast_sd ( a + l*m + 2 + i);
                    register __m256d a3 = _mm256_broadcast_sd ( a + l*m + 3 + i);
                    register __m256d a4 = _mm256_broadcast_sd ( a + l*m + 4 + i);
                    register __m256d a5 = _mm256_broadcast_sd ( a + l*m + 5 + i);
                    register __m256d a6 = _mm256_broadcast_sd ( a + l*m + 6 + i);
                    register __m256d a7 = _mm256_broadcast_sd ( a + l*m + 7 + i);

                    // explicit-FMA form, equivalent to the operator
                    // notation used below:
                    // c0 = _mm256_fmadd_pd(a0, b_load, c0);
                    // c1 = _mm256_fmadd_pd(a1, b_load, c1);
                    // c2 = _mm256_fmadd_pd(a2, b_load, c2);
                    // c3 = _mm256_fmadd_pd(a3, b_load, c3);
                    c0 += a0*b_load;
                    c1 += a1*b_load;
                    c2 += a2*b_load;
                    c3 += a3*b_load;
                    c4 += a4*b_load;
                    c5 += a5*b_load;
                    c6 += a6*b_load;
                    c7 += a7*b_load;
                }
                _mm256_storeu_pd(c+(0+i)*ldc+j, c0);
                _mm256_storeu_pd(c+(1+i)*ldc+j, c1);
                _mm256_storeu_pd(c+(2+i)*ldc+j, c2);
                _mm256_storeu_pd(c+(3+i)*ldc+j, c3);
                _mm256_storeu_pd(c+(4+i)*ldc+j, c4);
                _mm256_storeu_pd(c+(5+i)*ldc+j, c5);
                _mm256_storeu_pd(c+(6+i)*ldc+j, c6);
                _mm256_storeu_pd(c+(7+i)*ldc+j, c7);
            }
        }
    }else{
        for ( l = 0; l < k; ++l )
        {                 
            for ( j = 0; j < n; ++j )
            { 
                for ( i = 0; i < m; ++i )
                { 
                    c(i, j, ldc) += a(l, i, DGEMM_MR) * b(l, j, DGEMM_NR);
                }
            }
        }
    }

}
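Design note on bl_dgemm_avx_3: the 8x4 shape needs 8 accumulators plus 8 broadcasts plus 1 b vector live per iteration, 17 names for 16 ymm registers, so the compiler may have to spill or re-issue broadcasts inside the inner loop; that is the tradeoff to measure against the 4x8 shape above.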