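// Complete Karatsuba multiplication. This version is binary (32-bit limbs), but the
// same scheme works for base-10 limbs as long as the helper functions it uses are rewritten.
// Posted by yaos, 2004-07-24 01:51:17
// multest.cpp : Defines the entry point for the console application.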
//
// Version modified for Intel P4 SSE2
#include "stdafx.h"
#include <string.h>
// Global test operands and result buffer used by _tmain:
// a and b are the two 65536-limb factors, c receives the 131072-limb product.
unsigned long a[65536];
unsigned long b[65536];
unsigned long c[131072];
// DWORD-string multiplication (schoolbook). Debugged; the original author reports
// a speed of < 12 clocks per DWORD.
//
// pL, tL : left operand, tL DWORDs, least significant DWORD first.
// pR, tR : right operand, tR DWORDs, least significant DWORD first.
// pA     : result buffer of tL + tR DWORDs.
//
// The inner loop ADDS each partial product to what is already stored in pA, so
// the caller has to zero pA beforehand to obtain a plain product. Only the final
// carry of every row is stored (not added) one DWORD past the row.
// Returns immediately when either length is zero.
void AsmMulLL(unsigned long *pL, unsigned long *pR, unsigned long *pA, unsigned long tL,
unsigned long tR)
{
if ((tL == 0) || (tR == 0))
return;
__asm {
// ecx = outer counter (limbs of pL), esi -> pL, edi -> pR, ebx -> pA
mov ecx, tL
mov esi, dword ptr [pL]
mov edi, dword ptr [pR]
mov ebx, dword ptr [pA]
pxor mm3, mm3
// Outer loop: one pass per limb of pL.
mbinmul2:
// Save the outer counter in edx and the row start of pA in eax;
// mm0 is the running 64-bit carry, mm1 the current pL limb,
// mm4 keeps the pR pointer so it can be restored after the row.
mov edx, ecx
mov eax, ebx
pxor mm0, mm0
mov ecx, tR
movd mm1, dword ptr [esi]
movd mm4, edi
// Inner loop: multiply the current pL limb by every limb of pR.
mbinmul3:
movd mm2, dword ptr [edi]
lea edi, [edi+4]
// 32x32 -> 64-bit unsigned product
pmuludq mm2, mm1
// carry + existing result limb + product still fits in 64 bits:
// (2^32-1)^2 + 2*(2^32-1) = 2^64-1
movd mm3, dword ptr [ebx]
paddq mm0, mm3
paddq mm0, mm2
// store the low DWORD, keep the high DWORD as the next carry
movd dword ptr [ebx], mm0
psrlq mm0, 32
lea ebx, [ebx+4]
loop mbinmul3
// Row finished: restore the pR pointer and store the last carry above the row.
movd edi, mm4
movd dword ptr [ebx], mm0
// Next row starts one DWORD further in both pL and pA.
mov ebx, eax
lea esi, [esi+4]
lea ebx, [ebx+4]
mov ecx, edx
loop mbinmul2
// Leave MMX state.
emms
}
}
// DWORD-string addition. Debugged.
//
// Computes pA[0..t) = pL[0..t) + pR[0..t), least significant DWORD first, and
// returns the carry out of the most significant DWORD.
// Each index is read before it is written, so pA may be the same array as pL or pR.
// Returns 0 when t is 0.
//
// NOTE: there is no C return statement after the asm block; the result is left
// in EAX by "movd eax, mm0", which relies on the compiler treating EAX as the
// function's return value (MSVC inline-assembly convention).
unsigned long AsmAddLL(unsigned long *pL, unsigned long *pR, unsigned long *pA, unsigned
long t)
{
if (t ==0)
return 0;
__asm
{
mov ecx, t
mov esi, dword ptr [pL]
mov edi, dword ptr [pR]
mov ebx, dword ptr [pA]
// mm0 = running carry / accumulator, mm1 and mm2 = current operand limbs
pxor mm0, mm0
pxor mm1, mm1
pxor mm2, mm2
// Point all three registers one past the end and count ecx from -t up to 0,
// so a single "add ecx, 1 / jne" serves as index update and loop test.
lea esi, [esi+4*ecx]
lea edi, [edi+4*ecx]
lea ebx, [ebx+4*ecx]
neg ecx
asmadd1:
movd mm1, dword ptr [esi+ecx*4]
movd mm2, dword ptr [edi+ecx*4]
// 64-bit sum of carry + left limb + right limb
paddq mm0, mm1
paddq mm0, mm2
// store the low DWORD, keep the overflow as the next carry
movd dword ptr [ebx+ecx*4], mm0
psrlq mm0, 32
add ecx, 1
jne asmadd1
// final carry becomes the return value
movd eax, mm0
emms
}
}
// Adds a single DWORD to a DWORD string. Debugged; the original author notes it
// could still be optimized :)
//
// Computes pA[0..t) = pL[0..t) + pR, where pR is a plain value (not a pointer),
// and returns the carry out of the most significant DWORD.
// The loop always walks all t DWORDs, even after the carry has died out.
// Each index is read before it is written, so pA may be the same array as pL.
// Returns 0 when t is 0.
//
// NOTE: as in AsmAddLL, the return value is whatever the asm block leaves in EAX;
// there is no C return statement.
unsigned long AsmAddLS(unsigned long *pL, unsigned long pR, unsigned long *pA, unsigned
long t)
{
if (t ==0)
return 0;
__asm
{
mov ecx, t
mov esi, dword ptr [pL]
mov ebx, dword ptr [pA]
pxor mm0, mm0
pxor mm1, mm1
// Seed the carry register with the value of the pR argument itself.
movd mm0, dword ptr [pR]
// End-relative addressing with a negative index counting up to zero.
lea esi, [esi+4*ecx]
lea ebx, [ebx+4*ecx]
neg ecx
asmadd1:
movd mm1, dword ptr [esi+ecx*4]
paddq mm0, mm1
// store the low DWORD, keep the overflow as the next carry
movd dword ptr [ebx+ecx*4], mm0
psrlq mm0, 32
add ecx, 1
jne asmadd1
// final carry becomes the return value
movd eax, mm0
emms
}
}
// Negation. Debugged.
//
// Replaces the t-DWORD number at p (least significant DWORD first) with its
// two's complement, in place. Each step computes 0xFFFFFFFF - limb + carry,
// i.e. the bitwise complement plus a carry that starts at 1.
// Does nothing when t is 0.
void AsmNeg(unsigned long * p, unsigned long t)
{
if (t == 0)
return;
__asm
{
// mm1 = 0x00000000FFFFFFFF (constant added to every limb)
// mm0 = 0x0000000100000000 (its high DWORD is the incoming carry, initially 1)
mov eax, dword ptr [p]
mov ecx, t
mov edx, 1
movd mm0, edx
psllq mm0, 32
pxor mm1, mm1
mov edx, 0xFFFFFFFF
movd mm1, edx
pxor mm3, mm3
asmneg0:
// mm2 = carry from the previous step (high DWORD of mm0)
movq mm2, mm0
psrlq mm2, 32
// mm2 = carry + 0xFFFFFFFF - limb
paddq mm2, mm1
movd mm3, [eax]
psubq mm2, mm3
// keep the full 64-bit value so its high DWORD is the next carry,
// and write the low DWORD back
movq mm0, mm2
movd [eax], mm2
lea eax, [eax+4]
loop asmneg0
emms
}
}
// DWORD-string subtraction. Debugged.
//
// Computes pA[0..t) = pL[0..t) - pR[0..t), least significant DWORD first, and
// returns the final borrow: 1 when pL < pR (the stored result is then the
// difference modulo 2^(32*t)), otherwise 0.
// Each index is read before it is written, so pA may be the same array as pL or pR.
// Returns 0 when t is 0.
//
// NOTE: the return value is whatever the asm block leaves in EAX; there is no
// C return statement.
unsigned long AsmSubLL(unsigned long * pL, unsigned long * pR, unsigned long * pA, unsigned
long t)
{
if (t == 0)
return 0;
__asm
{
mov ecx, t
mov esi, dword ptr [pL]
mov edi, dword ptr [pR]
mov ebx, dword ptr [pA]
// mm0 = running borrow (0 or 1), mm1 and mm2 = current operand limbs
pxor mm0, mm0
pxor mm1, mm1
pxor mm2, mm2
// End-relative addressing with a negative index counting up to zero.
lea esi, [esi+4*ecx]
lea edi, [edi+4*ecx]
lea ebx, [ebx+4*ecx]
neg ecx
asmsub1:
movd mm1, dword ptr [esi+ecx*4]
movd mm2, dword ptr [edi+ecx*4]
// 64-bit difference: left limb - right limb - borrow
psubq mm1, mm2
psubq mm1, mm0
movd dword ptr [ebx+ecx*4], mm1
// the sign bit (bit 63) of the 64-bit difference is the new borrow
psrlq mm1, 63
movq mm0, mm1
add ecx, 1
jne asmsub1
// final borrow becomes the return value
movd eax, mm0
emms
}
}
// Returns the smallest power of two that is strictly greater than n.
//
// Because the result is STRICTLY greater, a power-of-two input n yields 2 * n;
// callers that want "n rounded up to a power of two" have to halve the result
// in that case (this behavior is kept unchanged from the original).
//
// Special cases:
//   n == 0                         -> 0
//   no larger power of two fits in
//   unsigned long (top bit of n set) -> 0
//
// Written in portable code: the previous inline-assembly version had no return
// statement (it relied on EAX) and returned 1 for n >= 2^31, because the shift
// count of 32 is masked to 0 by the CPU.
unsigned long find2k(unsigned long n)
{
    unsigned long power = 1;

    if (n == 0)
        return 0;

    // Doubling an unsigned value is well defined; once the top bit is shifted
    // out, power becomes 0, which ends the loop and signals "not representable".
    while (power != 0 && power <= n)
        power <<= 1;

    return power;
}
// Clears a DWORD string: sets the s array elements starting at p to zero.
//
// p : first element to clear.
// s : number of elements (not bytes); nothing is touched when s is 0.
//
// Implemented with memset instead of the former hand-written MMX loop, which
// was compiler- and CPU-specific while doing exactly the job of the library call.
void AsmMemZero(unsigned long * p, unsigned long s)
{
    if (s == 0)
        return;

    memset(p, 0, (size_t)s * sizeof *p);
}
// Karatsuba core routine for two operands of equal length.
//
// With B = 2^32 and n = t / 2 limbs:
//   u = B^n * U1 + U0
//   v = B^n * V1 + V0
//   uv = (B^2n + B^n) U1*V1 + B^n (U1 - U0)*(V0 - V1) + (B^n + 1) U0*V0
//
// pL, pR : operands, t DWORDs each, least significant DWORD first.
// pA     : result, 2 * t DWORDs. It MUST be zero-filled on entry, because the
//          base-case multiply (AsmMulLL) accumulates into it. pA must not
//          overlap pL or pR.
// t      : operand length in DWORDs.
//
// Operands of at most 32 DWORDs, and operands with an odd length (which cannot
// be split into two equal halves), are multiplied with the schoolbook routine.
void mul_karatsuba_core(unsigned long *pL, unsigned long *pR, unsigned long *pA,
                        unsigned long t)
{
    if (t <= 32 || (t & 1) != 0)
    {
        AsmMulLL(pL, pR, pA, t, t);
        return;
    }

    const unsigned long half = t / 2;

    // One scratch allocation: [0, t) middle product, [t, t + half) |U1 - U0|,
    // [t + half, 2t) |V0 - V1|.
    unsigned long *scratch = new unsigned long[2 * t];
    unsigned long *prod = scratch;
    unsigned long *diffU = scratch + t;
    unsigned long *diffV = diffU + half;
    int negative = 0;   // non-zero when (U1 - U0) * (V0 - V1) is negative
    long carry;

    // Only the product buffer needs clearing (the multiply accumulates into it);
    // the difference buffers are fully overwritten by AsmSubLL.
    AsmMemZero(prod, t);

    // pA[0, t) = U0 * V0 and pA[t, 2t) = U1 * V1.
    mul_karatsuba_core(pL, pR, pA, half);
    mul_karatsuba_core(pL + half, pR + half, pA + t, half);

    // diffU = |U1 - U0|, diffV = |V0 - V1|; a borrow means the difference was
    // negative, so negate it and flip the sign of the product.
    if (AsmSubLL(pL + half, pL, diffU, half) == 1)
    {
        negative = !negative;
        AsmNeg(diffU, half);
    }
    if (AsmSubLL(pR, pR + half, diffV, half) == 1)
    {
        negative = !negative;
        AsmNeg(diffV, half);
    }
    mul_karatsuba_core(diffU, diffV, prod, half);

    // Middle term M = U0*V0 + U1*V1 +/- prod, which equals U0*V1 + U1*V0 and is
    // therefore in [0, 2 * B^t). It is built modulo B^t in prod, while carry
    // tracks the multiple of B^t: a borrow from the subtraction counts as -1 and
    // is always compensated by the carry of the following addition, so carry is
    // never negative once U1*V1 has been added.
    if (negative)
        carry = -(long)AsmSubLL(pA, prod, prod, t);
    else
        carry = (long)AsmAddLL(pA, prod, prod, t);
    carry += (long)AsmAddLL(pA + t, prod, prod, t);

    // Add M at limb offset half; its overflow has the same weight as carry.
    carry += (long)AsmAddLL(pA + half, prod, pA + half, t);

    // Propagate the accumulated carry into the top half limbs of the result.
    if (carry > 0)
        AsmAddLS(pA + half + t, (unsigned long)carry, pA + half + t, half);

    delete[] scratch;
}
// Redesigned, optimized multiplication front end for operands of any lengths.
//
// pL, tL : left operand, tL DWORDs, least significant DWORD first.
// pR, tR : right operand, tR DWORDs, least significant DWORD first.
// pA     : result, tL + tR DWORDs. The caller MUST zero-fill it beforehand:
//          the schoolbook path and the chunked path both add into pA.
//
// Strategy:
//   * the operands are swapped so that the first one is the longer one;
//   * a short operand (<= 32 DWORDs) uses the schoolbook routine;
//   * if the long operand is at least twice as long as the short one, it is cut
//     into pieces of tR DWORDs that are multiplied one by one and accumulated;
//   * otherwise both operands are zero-padded to a common power-of-two length
//     and handed to mul_karatsuba_core.
void mul_karatsuba(unsigned long *pL, unsigned long *pR, unsigned long *pA,
                   unsigned long tL, unsigned long tR)
{
    if (tL < tR) // make sure the first factor is the longer one
    {
        mul_karatsuba(pR, pL, pA, tR, tL);
        return;
    }
    if (tR <= 32) // the ordinary method is better here
    {
        AsmMulLL(pL, pR, pA, tL, tR);
        return;
    }

    unsigned long ratio = tL / tR;
    if (ratio >= 2) // the two lengths differ too much
    {
        unsigned long *pT = new unsigned long[2 * tR];
        for (unsigned long i = 0; i < ratio; i++)
        {
            AsmMemZero(pT, 2 * tR);
            mul_karatsuba(pL + i * tR, pR, pT, tR, tR); // partial product
            // Add the partial product into pA; there should be no carry out,
            // because the running sum is itself a product of (i + 1) * tR by
            // tR DWORDs and so fits in the window that ends here.
            AsmAddLL(pT, pA + i * tR, pA + i * tR, 2 * tR);
        }
        unsigned long r = tL % tR;
        if (r > 0) // handle the remaining part
        {
            AsmMemZero(pT, 2 * tR);
            mul_karatsuba(pL + tL - r, pR, pT, r, tR);
            AsmAddLL(pT, pA + tL - r, pA + tL - r, tR + r);
        }
        delete[] pT;
        return;
    }

    // Pad both operands to l DWORDs, l being tL rounded up to a power of two.
    // find2k returns a power STRICTLY greater than its argument, so an exact
    // power of two comes back doubled and is halved again.
    unsigned long l = find2k(tL);
    if (l == 2 * tL)
        l = l / 2;

    unsigned long *src1 = new unsigned long[l];
    unsigned long *src2 = new unsigned long[l];
    unsigned long *dest = new unsigned long[2 * l];

    // Copy the operands and clear only the padding above them.
    memcpy(src1, pL, tL * sizeof *src1);
    AsmMemZero(src1 + tL, l - tL);
    memcpy(src2, pR, tR * sizeof *src2);
    AsmMemZero(src2 + tR, l - tR);
    // The core routine requires a zero-filled result buffer.
    AsmMemZero(dest, 2 * l);

    mul_karatsuba_core(src1, src2, dest, l);

    // Only the low tL + tR DWORDs can be non-zero; they overwrite pA.
    memcpy(pA, dest, (tL + tR) * sizeof *dest);

    delete[] src1;
    delete[] src2;
    delete[] dest;
}
// Program entry point: multiplies two 65536-limb numbers whose limbs are all
// 0xFFFFFFFF, using the global arrays a and b as factors and c as the product.
int _tmain(int argc, _TCHAR* argv[])
{
    const unsigned long limbs = 65536;
    unsigned long k;

    // The product buffer is twice as long as a factor and has to start out cleared.
    k = 0;
    while (k < 2 * limbs)
    {
        c[k] = 0;
        ++k;
    }

    // Fill both factors with all-ones limbs.
    for (k = limbs; k-- > 0; )
    {
        a[k] = 0xFFFFFFFF;
        b[k] = 0xFFFFFFFF;
    }

    mul_karatsuba(a, b, c, limbs, limbs);
    return 0;
}