> my speedy SGEMM
vvolkov
post Jan 15 2008, 04:37 AM
Post #41



*****

Group: Extranet Users
Posts: 107
Joined: 6-October 07
From: Berkeley, CA
Member No.: 72,970
Org.: UC Berkeley



Hey Paul,

I've got 185 Gflop/s with your code on GeForce 8800 GTX. I think you can fix up the last iteration without an if-statement --- run the iterations up to k-4, then do the last iteration individually.
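
A minimal sketch of the peeling idea, written as plain host-side C++ rather than as the actual kernel. The names `Chunk`, `load4`, `mad4`, `dot_guarded` and `dot_peeled` are made up for illustration, and `k` is assumed to be a positive multiple of 4. The guarded version tests on every pass whether there is a next chunk to prefetch; the peeled version loops only while a next chunk exists and then handles the last chunk on its own:

```cpp
// Illustrative only: a 4-wide dot product with software prefetch,
// standing in for the k-loop of the SGEMM kernel.
struct Chunk { float x[4], y[4]; };

static Chunk load4(const float *x, const float *y, int i) {
    Chunk c;
    for (int j = 0; j < 4; j++) { c.x[j] = x[i + j]; c.y[j] = y[i + j]; }
    return c;
}

static float mad4(const Chunk &c, float sum) {
    for (int j = 0; j < 4; j++) sum += c.x[j] * c.y[j];
    return sum;
}

// Prefetch guarded by an if-statement inside the loop.
float dot_guarded(const float *x, const float *y, int k) {
    Chunk cur = load4(x, y, 0), next = cur;
    float sum = 0.0f;
    for (int i = 0; i < k; i += 4) {
        if (i + 4 < k) next = load4(x, y, i + 4);  // branch on every pass
        sum = mad4(cur, sum);
        cur = next;
    }
    return sum;
}

// Loop runs up to k-4; the last iteration is done individually.
float dot_peeled(const float *x, const float *y, int k) {
    Chunk cur = load4(x, y, 0);
    float sum = 0.0f;
    for (int i = 0; i < k - 4; i += 4) {
        Chunk next = load4(x, y, i + 4);  // always in bounds, no branch
        sum = mad4(cur, sum);
        cur = next;
    }
    return mad4(cur, sum);  // final chunk: nothing left to prefetch
}
```

Both functions return the same sum; the peeled one simply has no branch in the loop body.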

Vasily
pleventi
post Jan 15 2008, 05:17 AM
Post #42



**

Group: Members
Posts: 14
Joined: 23-October 07
From: Toronto, Canada
Member No.: 75,275
Org.: Altera Corp.



QUOTE(vvolkov @ Jan 15 2008, 12:37 AM)
Hey Paul,

I've got 185 Gflop/s with your code on GeForce 8800 GTX. I think you can fix up the last iteration without an if-statement --- run the iterations up to k-4, then do the last iteration individually.

Vasily


Yup, that's what I meant with my comment if it wasn't clear :-)

Now to try out a few hand-coded kernels in assembly using wumpus' assembler -- I'm really starting to tire of monkeying with the C code to try to influence the compiler + ptxas to do what I want!
vvolkov
post Jan 15 2008, 09:31 AM
Post #43






I'm sorry, that comment was so long that I didn't manage to read all of it ;-)

I had another >180 Gflop/s code that now gives up to 189 Gflop/s when enhanced with your prefetching. It has less pointer arithmetic because it uses different layouts for A's and B's blocks in the shared memory. I'm posting it here.

Indeed, these fast codes have much to do with playing catch with the compiler.

CODE
// c[0..15] += a * b[0..15], fully unrolled
__device__ void saxpy( float a, float *b, float *c )
{
   c[0] += a*b[0];
   c[1] += a*b[1];
   c[2] += a*b[2];
   c[3] += a*b[3];
   c[4] += a*b[4];
   c[5] += a*b[5];
   c[6] += a*b[6];
   c[7] += a*b[7];
   c[8] += a*b[8];
   c[9] += a*b[9];
   c[10] += a*b[10];
   c[11] += a*b[11];
   c[12] += a*b[12];
   c[13] += a*b[13];
   c[14] += a*b[14];
   c[15] += a*b[15];
}

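// C = alpha*A*B^T + beta*C; each thread block handles one 32x32 block of C.
// The indexing below appears to assume a 32x2 thread block (inx in 0..31, iny in 0..1).
__global__ void sgemmNT( const float *A, int lda, const float *B, int ldb, float* C, int ldc, int k, float alpha, float beta )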
{
   int inx = threadIdx.x;
   int iny = threadIdx.y;
   int ibx = blockIdx.x * 32;
   int iby = blockIdx.y * 32;

   A += ibx + inx + __mul24( iny, lda );
   B += iby + (inx%16) + __mul24( inx/16 + iny*2, ldb );
   C += ibx + inx + __mul24( iby + iny*16, ldc );

   const float *A1 = A + 2*lda;
   const float *B1 = B + 16;

   float a1 = A[0];
   float a2 = A1[0];
   float b1 = B[0];
   float b2 = B1[0];

   const float *Blast = B + k*ldb;

   A  += 4*lda;
   A1 += 4*lda;
   B  += 4*ldb;
   B1 += 4*ldb;

   __shared__ float a[160], b[128];
   float c[16] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};

   float *a0 = a + 5*inx;
   float *b0 = b + 32*iny;
   float *_b = b + 64*iny;

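   // Main loop: store the prefetched values to shared memory, issue the
   // global loads for the next step, then do the MADs on the current step.
   do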
   {
       a0[iny]    = a1;
       a0[iny+2]  = a2;
       b0[inx]    = b1;
       b0[inx+64] = b2;
       __syncthreads();

       a1 = A[0];
       a2 = A1[0];
       b1 = B[0];
       b2 = B1[0];

       saxpy( a0[0], _b,    c );
       saxpy( a0[1], _b+16, c );
       saxpy( a0[2], _b+32, c );
       saxpy( a0[3], _b+48, c );

       A  += 4*lda;
       A1 += 4*lda;
       B  += 4*ldb;
       B1 += 4*ldb;
       __syncthreads();

   } while( B < Blast );

   // Last iteration, done separately: nothing is left to prefetch.
   a0[iny]    = a1;
   a0[iny+2]  = a2;
   b0[inx]    = b1;
   b0[inx+64] = b2;
   __syncthreads();

   saxpy( a0[0], _b,    c );
   saxpy( a0[1], _b+16, c );
   saxpy( a0[2], _b+32, c );
   saxpy( a0[3], _b+48, c );

   for( int i = 0; i < 16; i++, C += ldc )
       C[0] = alpha*c[i] + beta*C[0];
}
pleventi
post Jan 15 2008, 02:21 PM
Post #44






I'll give it a whirl and see whether I can eke anything more out of it. So 189 GFLOP/s on an 8800 GTX is the peak now, eh?

A couple quick notes on the previous code I posted:

1) Be sure to compile with -maxrregcount 32. I think it otherwise uses 33 regs and slows down a tad bit (but not much)

2) When you do that, the compiler doesn't allocate registers too well. Transform the for(i=0;i<k;i+=4) loop into for(i=k-4;i>=0;i-=4) instead. This eliminates one register (storing k) and allows the compiler to find a slightly better 32 register solution that results in all four global loads being moved to the top of the barrier region as intended.

3) The compiler still has about 3-5 instructions too many in the resulting code. I've got an assembly version of the code I'll upload tonight as a cubin + loader once I figure out how to get my own cubin going!
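
Point 2 can be pictured with a small host-side C++ sketch. The function names are hypothetical and `k` is assumed to be a positive multiple of 4. The register saving is what was observed with this particular compiler; the code below cannot demonstrate it by itself. Both loops visit the same offsets, but the second one only compares its counter against zero, so `k` does not have to stay live inside the loop:

```cpp
// Counts up: the loop bound k must be kept around for the comparison.
int sum_offsets_up(const int *data, int k) {
    int sum = 0;
    for (int i = 0; i < k; i += 4) sum += data[i];
    return sum;
}

// Counts down: k is only used to initialise i; the test is against zero.
int sum_offsets_down(const int *data, int k) {
    int sum = 0;
    for (int i = k - 4; i >= 0; i -= 4) sum += data[i];
    return sum;
}
```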

- Paul
vvolkov
post Jan 16 2008, 06:37 AM
Post #45






Paul,

My last version has 86 instructions in the inner loop vs. 92 in yours (I checked that yours uses 32 registers and all global loads are at the top). Can you beat that using assembly? ;-)

By the way, sticking to the 6-cycle model of a shared memory MAD, the 86-instruction version yields an estimate of 128*32*16 flops / (64*6 + 22*4 cycles) * 1.35 GHz = 187.4 Gflop/s, versus the 189 Gflop/s that is observed. However, this model may not be that precise, since it predicts only 178 Gflop/s for your 185 Gflop/s code.

Vasily
pleventi
post Jan 16 2008, 09:22 AM
Post #46






I *think* that the two pairs of add.half instructions are packed into one instruction each or something like that (forget where I read that). I'm not sure whether they issue somehow in parallel too, or if just instruction fetch & decode is improved. That may be a source of misprediction.

My latest code is 86 too (not including hand assembly) -- turns out we both made effectively the same transformations but with very different looking code...

The 86 can be cut down to 85 as follows, if I could just get an assembled cubin to link properly (I can't get the kernel to invoke properly using the cuLaunch stuff):

(1) Eliminate use of $ofs4 in second half of loop. Instead, just directly add 0x80 to $ofs2. Then, at the end/start of loop, subtract 0x80 from $ofs2. So we've removed an add and added a subtract (no change).

(2) But we now have $ofs4 free to store the base address of A or B, and can thus eliminate the load from $r30 at the top of the loop. Saves one instruction.

Anyway, enough micro-tweaking -- time to try other approaches to see if we can crack 200. With the global load latency hiding/prefetch trick, it is less important to have multiple blocks in flight, which may make more register-intensive kernels that look at more data at once worthwhile.

- Paul
langermatze
post Jan 19 2008, 01:38 PM
Post #47



*

Group: Members
Posts: 4
Joined: 9-October 07
From: Germany
Member No.: 73,310
Org.: FAU Erlangen-Nuremberg



Hi!

I tried nearly the same approach some time ago to get better results for cublasCgemm using CUDA/CUBLAS 1.0. Unfortunately, the resulting timings were much slower than cublasCgemm.
Now I am really surprised that this approach works so fast for float matrices with CUDA/CUBLAS 1.1. But the benchmarks of my similar approach for complex numbers did not change significantly, and it is still slower than cublasCgemm. I also had no luck adjusting your code to work with complex matrices so that it outperforms cublasCgemm. Have you ever tried it?
vvolkov
post Jan 20 2008, 10:02 AM
Post #48






I think that achieving high performance in CGEMM should be easier than doing it in SGEMM, since CGEMM has higher computational intensity (2x more data but 4x more flops). I didn't try it but I guess it won't work well with 32x32 block size in matrix C. Probably, smaller blocks like 32x16 may give better performance. I should note that it took a couple of days to try different block sizes and coding techniques and analyze disassembler outputs until I made my SGEMM work fast.

Thinking of non-square blocks in matrix C, I tried some today and finally got 205 Gflop/s with a 64x16 block size. This block size allows keeping A's blocks entirely in registers, thereby avoiding the shared memory overheads (like moving data from registers to the shared memory and back). It also reduced consumption of the shared memory, which made it possible to achieve 206 Gflop/s in sgemm( 'N', 'N', ... ). (That in turn improved my LU with partial pivoting to 168 Gflop/s.)
Mark Harris
post Jan 22 2008, 02:09 PM
Post #49



******

Group: Members
Posts: 382
Joined: 15-February 07
From: Brisbane, Australia
Member No.: 40,950
Org.: NVIDIA



Great work guys. Regarding your modeling of the instruction costs...

How many warps are active in your code? I ask because if you have fewer than 6 warps active (192 threads), you won't fully hide the arithmetic pipeline latency with back-to-back dependent arithmetic instructions.

The latency is approximately 22 clocks (this is the 1.35 GHz clock on 8800 GTX), and it takes 4 clocks to execute an arithmetic instruction (ADD, MUL, MAD, etc.) for a whole warp. If you have back-to-back MAD instructions with shared memory operands, then the hardware instruction set does, as you have surmised, need extra instructions for storing results to shared memory and for preparing shared memory addresses for input operands. So those instructions aren't likely to actually be back to back, and you could possibly hide the latency with 3 warps minimum. However, if you have back-to-back dependent instructions with only register operands, they really can be back to back, so:

MAD r3, r1, r1, r2
MAD r1, r3, r2, r4

Then you have to have enough warps to cover the full ~22 clocks latency or you will stall.

This is just one example to show you how a simplistic performance model like "6 cycles per shared memory MAD" can only ever be approximate.
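
The warp counts above follow from a simple ceiling division, sketched here in host-side C++. The function name is made up, and the 8-clock case (one extra 4-clock instruction between two dependent MADs) is only an assumption chosen to illustrate where a figure like "3 warps" could come from:

```cpp
// Warps needed so that a dependent instruction never waits on the pipeline.
// latency_clocks: arithmetic pipeline latency.
// clocks_between_dependent: issue clocks one warp spends between two
// dependent instructions (4 if they are truly back to back).
int warps_to_hide_latency(int latency_clocks, int clocks_between_dependent) {
    // Integer ceiling of latency_clocks / clocks_between_dependent.
    return (latency_clocks + clocks_between_dependent - 1) / clocks_between_dependent;
}
```

`warps_to_hide_latency(22, 4)` gives 6 warps, i.e. 192 threads, and `warps_to_hide_latency(22, 8)` gives 3.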

Mark
vvolkov
post Jan 24 2008, 11:21 AM
Post #50






Mark, thanks for the discussion.
QUOTE(Mark Harris @ Jan 22 2008, 06:09 AM)
How many warps are active in your code?  I ask because if you have fewer than 6 warps active (192 threads), you won't fully hide the arithmetic pipeline latency with back-to-back dependent arithmetic instructions.
In sgemm I have 2 warps per thread block and up to 4 thread blocks per multi-processor. So, it should be enough. Also, all MADs come in blocks of 16 independent instructions, which means there should be no back-to-back dependencies (only RARs, which are usually safe).

I assume that there is no penalty for outputting to the same registers that are used as input in the same instruction, like in MAD r1, r1, r2, r3.

Also, in the microbenchmark that I cited, I have the same dependencies in the code that uses the shared memory and in the code that does not. Both codes are run with 256 threads per thread block. Still, they get 229 vs. 338 Gflop/s respectively.

QUOTE(Mark Harris @ Jan 22 2008, 06:09 AM)
If you have back to back MAD instructions with shared memory operands, then the hardware instruction set does, as you have surmised, need extra instructions for storing results to shared memory and for preparing shared memory addresses for input operands.
But I don't write to the shared memory in MADs. I have writes that store the data I fetch from the global memory, but there are only a few of them per dozens of MADs. Also, I count all of the extra instructions used for preparing the address (that I can see with decuda) when doing the estimates.

I'd like to see code that achieves >230 Gflop/s in shared memory MADs. By "shared memory MADs" I mean instructions like "mad.rn.f32 $r1, s[0x0010], $r2, $r3" in decuda notation.

I attached the code that runs at up to >200 Gflop/s on GeForce 8800 GTX in both sgemm( 'N', 'N', ... ) and sgemm( 'N', 'T', ... ).
Attached File(s)
Attached File  sgemmN_012408.zip ( 4.7K ) Number of downloads: 535
 
Mark Harris
post Jan 24 2008, 04:34 PM
Post #51






QUOTE(vvolkov @ Jan 24 2008, 12:21 PM)
But I don't write to the shared memory in MADs. I have writes that store the data I fetch from the global memory, but there are only a few of them per dozens of MADs. Also, I count all of the extra instructions used for preparing the address (that I can see with decuda) when doing the estimates.

Well, like I said: :)
QUOTE(Mark Harris)
If you have ... instructions with shared memory operands, then the hardware instruction set does, as you have surmised, need extra instructions for storing results to shared memory and for preparing shared memory addresses for input operands.

So I believe that in your code:
CODE
c[0] += a*b[0];
c[1] += a*b[1];
c[2] += a*b[2];
c[3] += a*b[3];
c[4] += a*b[4];
c[5] += a*b[5];
c[6] += a*b[6];
c[7] += a*b[7];
c[8] += a*b[8];
// etc.

Each of those lines generates two instructions (ignoring possible store to c[] depending on where the compiler has put it): the mad and an address update. You should be able to see this in the decuda output, right? Is that enough to explain why perf is lower with shared memory operands?

Mark
mfatica
post Jan 24 2008, 04:38 PM
Post #52



*******

Group: NVIDIA Employees
Posts: 882
Joined: 16-February 07
From: Santa Clara, CA
Member No.: 41,162
Org.: NVIDIA



Vasily and Paul,
great work on SGEMM!!!

Massimiliano
pleventi
post Jan 24 2008, 06:09 PM
Post #53






Hi Mark,

According to the disassembly output, there is only ~*one* instruction per MAD from the saxpy code. The c[i] values are each in their own register, the a value is in a register, and each b[i] is a reference into shared memory at a direct offset from an $ofs register. If the b[i] values were separated by more than 0x80 bytes, then the $ofs register would need to be updated between each MAD, which would introduce additional instructions -- but that's not the case here. The code has been very carefully structured by Vasily to ensure that (a) most operands are in registers and (b) nearly all references into the shared memory do not require a load/update of the $ofs register.

The result is an inner kernel (for previous versions of the code) with 64 MAD instructions of the form:

load r16 with a[0]
MAD r0, s[$ofs1 + 0x00], r16, r0
MAD r1, s[$ofs1 + 0x04], r16, r1
MAD r2, s[$ofs1 + 0x08], r16, r2
...
MAD r15, s[$ofs1 + 0x3C], r16, r15
load r17 with a[1]
MAD r0, s[$ofs1 + 0x40], r17, r0
MAD r1, s[$ofs1 + 0x44], r17, r1
...
increment $ofs1 by 0x80
load r18 with a[2]
MAD r0, s[$ofs1 + 0x00], r18, r0
etc.

Or something like that (don't have the code handy right now). So there are no dependent reads/writes from one MAD to the next. Nor are there any $ofs updates. It just seems like any MAD with a shared memory operand, regardless of # of active warps or # of blocks, has a 33% lower throughput than one using a register operand.

If you'd like, I can post a simple kernel that demonstrates the reduction in throughput as a function of # of threads. That kernel shows the 22 cycle latency and demonstrates that indeed the processor can fully hide the 22 cycle latency after ~192 threads, but that the shared-memory MAD is capped at 66% of the throughput of an all-register MAD.

Regards,

- Paul
julien38
post Jan 29 2008, 02:38 PM
Post #54



*

Group: Members
Posts: 5
Joined: 10-November 07
Member No.: 77,987




I got 205 Gflop/s on my 8800GTX !!
(vs 115 for CUBLAS)

Great work !

Maybe one of you can answer this question:
My application requires matrix-vector products: C=A*B with:
C[N,1], A[N,2N] and B[2N,1]

(!) important: A is constant, so host->device transfer is not that critical

My dream is to reach 50microsec with N=1000,
but cublas SGEMV is ~10 times too slow

do you think it is possible to build an optimized sgemv taking advantage of
the fact that A is constant ?

cheers,
J







vvolkov
post Jan 29 2008, 07:05 PM
Post #55






QUOTE(julien38 @ Jan 29 2008, 06:38 AM)
My application requires matrix-vector products: C=A*B with:
  C[N,1], A[N,2N] and B[2N,1]

(!) important: A is constant, so host->device transfer is not that critical

My dream is to reach 50microsec with N=1000,
but cublas SGEMV is ~10 times too slow
I get about 15 microsec latency in memcpy to/from the GPU memory, which gives you 30 us overhead for transferring B and C alone. Then you need to read A (8000000 bytes) from the GPU memory at least once, which takes 114 us when reading at 70 GB/s, the max sustained read bandwidth on this card. So, 114+30 ~ 150 microseconds should be a more realistic goal.

However, getting even this number may be tricky. A naive implementation would create 1000 threads and lose, because this may not be enough to hide the global memory latency, giving you only a fraction of those 70 GB/s. The solution I'd think of is splitting A into block columns of size N x K. Then you do a few SGEMVs to get matrix C[N,N/K]. You transfer it to the CPU and finish the reduction there to get C[N,1]. The catch is that you can do those few SGEMVs at the same time, so that you have enough threads to hide the latency.
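
A CPU-only sketch of the block-column idea in C++, to show the data flow rather than a tuned kernel. The function name, the row-major layout and the requirement that K divides the number of columns are assumptions made for the example:

```cpp
#include <vector>

// y = A*x for a row-major rows x cols matrix, computed as cols/K independent
// partial products (one per block column) followed by a final reduction.
std::vector<float> gemv_block_columns(const std::vector<float> &A,
                                      const std::vector<float> &x,
                                      int rows, int cols, int K) {
    int blocks = cols / K;  // assumes K divides cols
    std::vector<float> partial(rows * blocks, 0.0f);

    // Stage 1: every (row, block) pair is independent; on the GPU these
    // would be the SGEMVs that run at the same time.
    for (int b = 0; b < blocks; b++)
        for (int r = 0; r < rows; r++)
            for (int c = b * K; c < (b + 1) * K; c++)
                partial[r * blocks + b] += A[r * cols + c] * x[c];

    // Stage 2: the final reduction, done on the CPU in the suggestion above.
    std::vector<float> y(rows, 0.0f);
    for (int r = 0; r < rows; r++)
        for (int b = 0; b < blocks; b++)
            y[r] += partial[r * blocks + b];
    return y;
}
```

Stage 1 exposes rows x blocks independent pieces of work instead of just rows, which is where the extra threads for latency hiding would come from.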
Sarnath
post Jan 31 2008, 08:02 AM
Post #56



********

Group: Members
Posts: 1,567
Joined: 23-November 07
From: Bangalore
Member No.: 79,873
Org.: HCL Technologies



Paul,

First of all, congrats on your speedup!!! It's good to know.

I know nothing about SGEMM. But still, I had a look at your code.

I have a few suggestions. Pardon me if my ideas are inappropriate for SGEMM.

1) When you have multi-dimensional blocks (16 x 4, for instance), the way threads are mapped into WARPs might NOT be straightforward. Some experiments have been conducted that show this clearly.

See the 2 FORUM entries:
http://forums.nvidia.com/index.php?showtopic=58052
http://forums.nvidia.com/index.php?showtopic=57779 A link mentioned in the URL above

2)
In my experience, I have found that having just 32 threads per block eliminates the need for __syncthreads(), double-buffering and so on. As long as you schedule enough blocks (6 blocks per multiprocessor), you should be cool. Usually, it results in a 1.3x speedup.
Check out: http://forums.nvidia.com/index.php?showtopic=54875

But I am not sure if the second idea would help you. If your code is dependent on the block dimensions then it may not be possible to re-code.


--------------------
Ignorance Rules; Knowledge Liberates!
pkeir
post Apr 3 2008, 11:36 AM
Post #57



***

Group: Members
Posts: 47
Joined: 17-February 07
Member No.: 41,280
Org.: University of Glasgow



Hi,

I downloaded sgemmN_012408.zip but only managed ~27Gflops/sec for your SGEMM, versus ~120Gflops/sec for the BLAS version. There is zero error/difference for each run. I'm using version 1.0 of the SDK, and have an 8800GTX. Where have I gone wrong?

Kind Regards,
Paul
vvolkov
post Apr 3 2008, 11:43 AM
Post #58






QUOTE(pkeir @ Apr 3 2008, 03:36 AM)
I downloaded sgemmN_012408.zip but only managed ~27Gflops/sec for your SGEMM, versus ~120Gflops/sec for the BLAS version. I'm using version 1.0 of the SDK, and have an 8800GTX. Where have I gone wrong?


Try using SDK 1.1. The code's performance is strongly compiler-dependent.

Also, I'd recommend compiling it as 32-bit if you are on a 64-bit system. Otherwise you might lose ~10% of performance.

Vasily
pkeir
post Apr 7 2008, 09:18 AM
Post #59






Thanks, of course 1.1 kills all my cygwin bash scripts, but hey: 206 Gflops on my 8800GTX. Very impressive!
vvolkov
post May 2 2008, 01:44 PM
Post #60






I found that ssyrk in CUBLAS 2.0 runs on the 8800GTX at only up to 127 Gflop/s. Apparently it is based on the old sgemm. Here is an example that runs as fast as the new sgemm, i.e. at up to 200+ Gflop/s.
Attached File(s)
Attached File  ssyrkLN_050208.zip ( 5.32K ) Number of downloads: 262
 
