[Math-atlas-commits] CVS: AtlasBase/Clint atlas-lvl2.base, 1.32, 1.33
Brought to you by:
rwhaley,
tonyc040457
From: R. C. W. <rw...@us...> - 2009-04-27 23:06:45
|
/*
 * SSE2 rank-1 update (GER) kernel generator.
 *
 * Reconstructed from the CVS commit mail for
 * /cvsroot/math-atlas/AtlasBase/Clint/atlas-lvl2.base, r1.32 -> r1.33
 * (2009-04-27 23:06:24, empty log message).  The mail carried the code as a
 * context diff; only the routines added by that revision are defined here.
 *
 * Indentation convention used by every generator below: `spc` points into
 * the tail of a blank-filled, NUL-terminated buffer (see main), so `spc-3`
 * is a string that is three blanks longer, i.e. one indentation level
 * deeper.  Callers must leave enough blanks in front of `spc` for the
 * deepest nesting they generate.
 */

/*
 * Defined in atlas-lvl2.base outside the hunks shown in the commit mail.
 * NOTE(review): gen_MUxNU's parameter list is inferred from its call in
 * genMloop and from gen_sMUxNU's parallel signature -- confirm against the
 * base file.  PrintUsage's signature is taken verbatim from the diff.
 */
void gen_MUxNU(FILE *fpout, char *spc, int mu, int nu, int I0);
void PrintUsage(char *name, char *arg, int i);

enum
{
   GEN_NAMELEN = 16,    /* room for "+lda" plus any int, plus NUL */
   GEN_MAXUR   = 9999   /* largest unroll accepted on the command line */
};

/*
 * Emits scalar (_sd) code for an mu x nu block of A += x * y^T:
 * for each column j and row i, load A(i,j), add x_i*y_j, store it back.
 * x_i is loaded once, while column 0 is being generated.
 * Column j > 0 is addressed as A+i+<row>+lda<j>; column 0 has no lda term.
 * The product is folded into the add so that no temporary beyond the
 * x/y/a variables declared by genR1 is needed.
 */
void gen_sMUxNU(
   FILE *fpout,   /* file to print to */
   char *spc,     /* string with indentation spaces */
   int mu,        /* unrolling on M dimension */
   int nu,        /* unrolling on N dimension */
   int I0)        /* what to start I at (for unrolling w/o register block) */
{
   int i, j;
   char colof[GEN_NAMELEN];   /* "" for column 0, else "+lda<j>" */

   for (j=0; j < nu; j++)
   {
      if (j)
         snprintf(colof, sizeof colof, "+lda%d", j);
      else
         colof[0] = '\0';
      for (i=0; i < mu; i++)
      {
         if (!j)
            fprintf(fpout, "%sx%d = _mm_load_sd(X+i+%d);\n", spc, i, i+I0);
         fprintf(fpout, "%sa%d_%d = _mm_load_sd(A+i+%d%s);\n",
                 spc, i, j, i+I0, colof);
         fprintf(fpout,
                 "%sa%d_%d = _mm_add_sd(a%d_%d, _mm_mul_sd(x%d, y%d));\n",
                 spc, i, j, i, j, i, j);
         fprintf(fpout, "%s_mm_store_sd(A+i+%d%s, a%d_%d);\n",
                 spc, i+I0, colof, i, j);
      }
   }
}

/*
 * Emits "for (i=mstart; i < mend; i += nmu*mu) { ... }" whose body is nmu
 * back-to-back gen_MUxNU blocks, the k-th one starting at row offset k*mu.
 */
void genMloop(
   FILE *fpout, /* file to print to */
   char *spc,   /* string with indentation spaces */
   char *mstart,/* usually "0", but will be "M8" for UR=8 cleanup */
   char *mend,  /* ending clause for i */
   int nmu,     /* # of reps of mu register block in M unroll */
   int mu,      /* unrolling on M dimension */
   int nu)      /* unrolling on N dimension */
{
   int i;
   const int MU = nmu * mu;
   char *bspc = spc - 3;      /* indentation of the loop body */

   fprintf(fpout, "%sfor (i=%s; i < %s; i += %d)\n", spc, mstart, mend, MU);
   fprintf(fpout, "%s{/* ----- BEGIN M-LOOP BODY ----- */\n", spc);
   for (i=0; i < nmu; i++)
   {
      fprintf(fpout, "%s/* --- BEGIN MUxNU UNROLL %d --- */\n", bspc, i);
      gen_MUxNU(fpout, bspc, mu, nu, i*mu);
      fprintf(fpout, "%s/* --- END MUxNU UNROLL %d --- */\n", bspc, i);
   }
   fprintf(fpout, "%s}/* ----- END M-LOOP BODY ----- */\n", spc);
}

/*
 * Emits the body of one N-loop iteration:
 *   1. broadcast y0..y<nu-1> from Y, Y+incY1, ...
 *   2. the main M loop, running i from 0 to M<nmu*mu> in steps of nmu*mu
 *   3. if nmu*mu > 2: a guarded cleanup M loop of 2 rows up to M2
 *   4. a guarded scalar cleanup of one row when M != M2 (nested inside the
 *      block opened in step 3 when that block exists)
 * The names M<nmu*mu>, M2, incY<j> and lda<j> are the ones genR1 declares.
 * NOTE(review): with nmu*mu == 1 the main loop bound M1 equals M, yet the
 * scalar cleanup is still emitted for odd M -- confirm which unroll factors
 * gen_MUxNU is meant to support.
 */
void genNloopBody(
   FILE *fpout, /* file to print to */
   char *spc,   /* string with indentation spaces */
   int nmu,     /* # of reps of mu register block in M unroll */
   int mu,      /* unrolling on M dimension */
   int nu)      /* unrolling on N dimension */
{
   int j;
   const int MU = nmu*mu;
   char mstart[GEN_NAMELEN];  /* name of M rounded down to a multiple of MU */

   snprintf(mstart, sizeof mstart, "M%d", MU);
   fprintf(fpout, "%sy0 = _mm_load1_pd(Y);\n", spc);
   for (j=1; j < nu; j++)
      fprintf(fpout, "%sy%d = _mm_load1_pd(Y+incY%d);\n", spc, j, j);
   genMloop(fpout, spc, "0", mstart, nmu, mu, nu);
   if (MU > 2)
   {
      /* the main loop steps by MU, so its bound is the one to test against */
      fprintf(fpout, "%sif (M != %s)\n", spc, mstart);
      fprintf(fpout,"%s{/* ----- BEGIN VECTOR UNROLL M CLEANUP ----- */\n",spc);
      spc -= 3;
      /* TODO: for small mu, use a case statement rather than a loop here */
      genMloop(fpout, spc, mstart, "M2", 1, 2, nu);
   }
   fprintf(fpout, "%sif (M != M2)\n", spc);
   fprintf(fpout, "%s{/* ----- BEGIN SCALAR M CLEANUP ----- */\n", spc);
   gen_sMUxNU(fpout, spc-3, 1, nu, 0);
   fprintf(fpout, "%s}/* ----- END SCALAR M CLEANUP ----- */\n", spc);

   if (MU > 2)
   {
      spc += 3;
      fprintf(fpout, "%s}/* ----- END VECTOR UNROLL M CLEANUP ----- */\n", spc);
   }
}

/*
 * Emits "for (j=nstart; j < nend; j += nu, A += lda<nu>, Y += incY<nu>)"
 * wrapped around one genNloopBody.
 */
void genNloop(
   FILE *fpout, /* file to print to */
   char *spc,   /* string with indentation spaces */
   char *nstart,/* "0" for the main loop, the main loop's bound for cleanup */
   char *nend,  /* ending clause for j */
   int nmu,     /* # of reps of mu register block in M unroll */
   int mu,      /* unrolling on M dimension */
   int nu)      /* unrolling on N dimension */
{
   fprintf(fpout, "%sfor (j=%s; j < %s; j += %d, A += lda%d, Y += incY%d)\n",
           spc, nstart, nend, nu, nu, nu);
   fprintf(fpout, "%s{/* BEGIN N-LOOP UR=%d */\n", spc, nu);
   genNloopBody(fpout, spc-3, nmu, mu, nu);
   fprintf(fpout, "%s}/* END N-LOOP UR=%d */\n", spc, nu);
}

/*
 * Emits the complete GER routine `rout`: prototype, declarations, the main
 * N loop unrolled by nu and, when nu > 1, an NU=1 cleanup N loop.
 *
 * Declarations emitted for the generated code:
 *   M2            M rounded down to a multiple of 2
 *   M<nmu*mu>     M rounded down to a multiple of nmu*mu (omitted when the
 *                 product is 2, since it would duplicate M2)
 *   N<nu>         N rounded down to a multiple of nu
 *   lda<k>, incY<k> for k = 2..nu, built as k times the lda1 / incY1
 *                 parameters; k = nu is needed by the N-loop increment
 *   x<i>, y<j>, a<i>_<j> as __m128d for i < mu, j < nu
 */
void genR1(
   FILE *fpout, /* file to print to */
   char *spc,   /* string with indentation spaces */
   char *rout,  /* routine name */
   int nmu,     /* # of reps of mu register block in M unroll */
   int mu,      /* unrolling on M dimension */
   int nu)      /* unrolling on N dimension */
{
   int i, j;
   const int MU = nmu*mu;
   char nbnd[GEN_NAMELEN];

   fprintf(fpout, "%svoid %s\n", spc, rout);
   fprintf(fpout, "%s   (ATL_CINT M, ATL_CINT N, const SCALAR alpha,\n", spc);
   fprintf(fpout, "%s    const TYPE *X, ATL_CINT incX, const TYPE *Y,\n", spc);
   fprintf(fpout, "%s    ATL_CINT incY1, TYPE *A, ATL_CINT lda1)\n", spc);
   fprintf(fpout, "%s{/* BEGIN GER: nMU=%d, MU=%d, NU=%d */\n",
           spc, nmu, mu, nu);
   spc -= 3;
   fprintf(fpout, "%sATL_INT i, j;\n", spc);
   fprintf(fpout, "%sATL_CINT", spc);
   fprintf(fpout, " M2=((M>>1)<<1)");
   if (MU != 2)
      fprintf(fpout, ", M%d=((M/%d)*%d)", MU, MU, MU);
   fprintf(fpout, ", N%d=((N/%d)*%d)", nu, nu, nu);
   for (j=2; j <= nu; j++)
   {
      fprintf(fpout, ", lda%d=lda%d+lda1", j, j-1);
      fprintf(fpout, ", incY%d=incY%d+incY1", j, j-1);
   }
   fprintf(fpout, ";\n");
   fprintf(fpout, "%s__m128d x0", spc);
   for (j=1; j < mu; j++)
      fprintf(fpout, ", x%d", j);
   for (j=0; j < nu; j++)
      fprintf(fpout, ", y%d", j);
   for (j=0; j < nu; j++)
      for (i=0; i < mu; i++)
         fprintf(fpout, ", a%d_%d", i, j);
   fprintf(fpout, ";\n");

   snprintf(nbnd, sizeof nbnd, "N%d", nu);
   genNloop(fpout, spc, "0", nbnd, nmu, mu, nu);
   if (nu > 1)  /* later on, use case statement rather than NU=1 cleanup */
   {
      genNloop(fpout, spc, nbnd, "N", nmu, mu, 1);
   }
   spc += 3;
   fprintf(fpout, "%s}/* END GER: nMU=%d, MU=%d, NU=%d */\n",
           spc, nmu, mu, nu);
}

/*
 * Prints the usage message and terminates.  The explicit exit guarantees
 * that flag parsing never continues after a bad argument, whether or not
 * PrintUsage itself returns.
 */
static void UsageAbort(char *name, char *arg, int i)
{
   PrintUsage(name, arg, i);
   exit(EXIT_FAILURE);
}

/*
 * Consumes the value following flag args[*i] and returns it.  Aborts with
 * the usage message if the value is missing (reporting `errmsg`), is not a
 * plain decimal integer, or lies outside [1, GEN_MAXUR].
 */
static int GetUnrollArg(int nargs, char **args, int *i, char *errmsg)
{
   char *end;
   long val;

   if (++(*i) >= nargs)
      UsageAbort(args[0], errmsg, *i - 1);
   val = strtol(args[*i], &end, 10);
   if (end == args[*i] || *end != '\0' || val < 1 || val > GEN_MAXUR)
      UsageAbort(args[0], args[*i], *i);
   return((int)val);
}

/*
 * Parses the command line.  Recognised flags:
 *   -M <#> : repeat mu unroll # times in loop           (default 1)
 *   -m <#> : unroll (with register blocking) M loop by # (default 1)
 *   -n <#> : unroll&jam N loop by #                      (default 2)
 * Any other argument prints the usage message and terminates.
 * RETURNS: flag word, currently always 0.
 */
int GetFlags(int nargs, char **args, int *NMU, int *MU, int *NU)
{
   int i, flag=0;

   *NMU = *MU = 1;
   *NU = 2;

   for (i=1; i < nargs; i++)
   {
      if (args[i][0] != '-')
         UsageAbort(args[0], "No '-' preceding flag!", i);
      switch(args[i][1])
      {
      case 'M':
         *NMU = GetUnrollArg(nargs, args, &i, "out of flags in -M ");
         break;
      case 'm':
         *MU = GetUnrollArg(nargs, args, &i, "out of flags in -m ");
         break;
      case 'n':
         *NU = GetUnrollArg(nargs, args, &i, "out of flags in -n ");
         break;
      default:
         UsageAbort(args[0], args[i], i);
      }
   }
   return(flag);
}

#define NSPCS 128
/*
 * Builds the blank buffer used for indentation, parses the flags and writes
 * the generated GER kernel to stdout.
 */
int main(int nargs, char **args)
{
   int nmu, mu, nu, i;
   char spcbuf[NSPCS];
   char *spc;

   for (i=0; i < NSPCS-1; i++)
      spcbuf[i] = ' ';
   spcbuf[NSPCS-1] = '\0';
   /* start three blanks from the end; generators step backwards to indent */
   spc = spcbuf + NSPCS-4;
   (void) GetFlags(nargs, args, &nmu, &mu, &nu);
   genR1(stdout, spc, "Mjoin(PATL,ger1_a1_x1_yX)", nmu, mu, nu);
   return(0);
}