Why is my MTL matrix multiplication slower than Java?

Dream_soft 2005-04-16 04:13:04


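I don't know MTL well. For my research project I wrote a numerical program to evaluate it, and it turned out slower than the same computation done with Java's Colt package. I believe the problem is in my own program, but I don't know where.
Could an expert point me in the right direction? Thanks.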

Goal: the product of a matrix and a vector

Implementation 1: matrix with dense1D

#include "stdafx.h"
#include "mtl/matrix.h"
#include "mtl/dense1D.h"
#include "mtl/mtl.h"
#include <atltime.h>

using namespace ATL;
using namespace std;
using namespace mtl;

typedef matrix<double, rectangle<>, dense<>, row_major>::type Matrix;

void Mytest()
{
    const Matrix::size_type MAX_ROW = 100, MAX_COL = 100, TestA_ROW = 1, TestA_COL = 100;

    Matrix TestArray(MAX_ROW, MAX_COL);
    dense1D<double> TestA(TestA_COL), Result(TestA_COL);

    // Fill every row of the matrix with 0, 1, ..., MAX_COL-1.
    for (Matrix::size_type i = 0; i < MAX_ROW; ++i)
    {
        for (Matrix::size_type j = 0; j < MAX_COL; ++j)
        {
            TestArray(i, j) = Matrix::value_type(j);
        }
    }

    // Fill the vector with 0, 1, ..., TestA_COL-1.
    Matrix::size_type p;
    for (p = 0; p < TestA.size(); p++)
        TestA[p] = p;

    // Time 200000 matrix-vector products (CTime counts whole seconds).
    CTime BeginTime = CTime::GetCurrentTime();
    for (int j = 0; j < 200000; j++)
    {
        mult(TestArray, TestA, Result);
    }
    CTime EndTime = CTime::GetCurrentTime();
    CTimeSpan SpendTime = EndTime - BeginTime;
    // GetTotalSeconds() returns a 64-bit LONGLONG; cast it to match %ld.
    printf("Elapsed: %ld s\n", (long)SpendTime.GetTotalSeconds());
}

int _tmain(int argc, _TCHAR* argv[])
{
    // Run the benchmark six times.
    for (int i = 0; i <= 5; i++)
        Mytest();

    return 0;
}



Implementation 2: matrix with matrix

#include "stdafx.h"
#include "mtl/matrix.h"
#include "mtl/dense1D.h"
#include "mtl/mtl.h"
#include <atltime.h>

using namespace ATL;
using namespace std;
using namespace mtl;

typedef matrix<double, rectangle<>, dense<>, row_major>::type Matrix;

void Mytest()
{
    const Matrix::size_type MAX_ROW = 100, MAX_COL = 100, TestA_ROW = 1, TestA_COL = 100;

    // TestA and Result are 1 x 100 matrices here instead of dense1D vectors.
    Matrix TestArray(MAX_ROW, MAX_COL), TestA(TestA_ROW, TestA_COL), Result(TestA_ROW, TestA_COL);

    // Fill every row of the matrix with 0, 1, ..., MAX_COL-1.
    for (Matrix::size_type i = 0; i < MAX_ROW; ++i)
    {
        for (Matrix::size_type j = 0; j < MAX_COL; ++j)
        {
            TestArray(i, j) = Matrix::value_type(j);
        }
    }

    // Fill the single row of TestA with 0, 1, ..., TestA_COL-1.
    for (Matrix::size_type i = 0; i < TestA_ROW; ++i)
    {
        for (Matrix::size_type j = 0; j < TestA_COL; ++j)
        {
            TestA(i, j) = Matrix::value_type(j);
        }
    }

    double OutResult = 0;
    // Time 200000 products of the 1 x 100 row with the 100 x 100 matrix.
    CTime BeginTime = CTime::GetCurrentTime();
    for (int j = 0; j < 200000; j++)
    {
        mult(TestA, TestArray, Result);
    }

    CTime EndTime = CTime::GetCurrentTime();
    CTimeSpan SpendTime = EndTime - BeginTime;
    // GetTotalSeconds() returns a 64-bit LONGLONG; cast it to match %ld.
    printf("Elapsed: %ld s\n", (long)SpendTime.GetTotalSeconds());
}

int _tmain(int argc, _TCHAR* argv[])
{
    // Run the benchmark six times.
    for (int i = 0; i <= 5; i++)
        Mytest();

    return 0;
}


PS: The two programs multiply in a different order. Since the goal is only to evaluate speed, please don't dwell on that. Thanks.
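Both programs time the loop with CTime, which counts whole seconds, so short runs get rounded heavily. As a rough sketch of a finer-grained measurement, the code below times the same kind of 100 x 100 matrix-vector product with std::chrono::steady_clock. It assumes a C++11 compiler (newer than the VC7.1 used in this thread), uses plain loops as a stand-in for MTL's mult, and the names matvec and time_matvec_ms are made up for illustration.

```cpp
#include <cassert>
#include <chrono>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of MTL: y = A * x for a row-major
// rows x cols matrix stored in a flat vector.
void matvec(const std::vector<double>& A, std::size_t rows, std::size_t cols,
            const std::vector<double>& x, std::vector<double>& y)
{
    assert(A.size() == rows * cols && x.size() == cols && y.size() == rows);
    for (std::size_t i = 0; i < rows; ++i) {
        double sum = 0.0;
        for (std::size_t j = 0; j < cols; ++j)
            sum += A[i * cols + j] * x[j];
        y[i] = sum;
    }
}

// Hypothetical helper: runs matvec `repeats` times on an n x n matrix and
// returns the elapsed wall-clock time in milliseconds (monotonic clock).
// The last result is left in y so the work stays observable to the caller.
double time_matvec_ms(std::size_t n, int repeats, std::vector<double>& y)
{
    std::vector<double> A(n * n), x(n);
    for (std::size_t i = 0; i < n; ++i) {
        x[i] = static_cast<double>(i);
        for (std::size_t j = 0; j < n; ++j)
            A[i * n + j] = static_cast<double>(j);  // same fill as the post
    }
    y.assign(n, 0.0);

    const auto begin = std::chrono::steady_clock::now();
    for (int r = 0; r < repeats; ++r)
        matvec(A, n, n, x, y);
    const auto end = std::chrono::steady_clock::now();

    return std::chrono::duration<double, std::milli>(end - begin).count();
}
```

Calling time_matvec_ms(100, 200000, y) would mirror the loop count used above; the number it returns says nothing about MTL itself, only about the plain loops in this sketch.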
Dream_soft 2005-05-11
Thanks, everyone. The problem went away after I compiled with ICC, so it looks like VC7.1 was the cause.
Dream_soft 2005-04-19
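Heh, Java of course offers no way to hand-optimize for a particular CPU; whether the JIT does so depends on the JVM's mood. What I meant is that things like array access elimination are the same on both sides.

Really, I just want to be lazy. If libraries like MTL have already done the work for us, we don't have to write and optimize everything piece by piece, which saves a lot of effort. So I'd still ask anyone who knows MTL to take a look and tell me where the problem is.

Yesterday I downloaded Blitz++. A lot still has to be written by hand, which is annoying and doesn't suit a lazy person like me.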
frankpzh 2005-04-18
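Heh, the key here is still the algorithm. If you use the n^3 matrix multiplication, where each step costs one multiplication and one addition, then for n=1000 there is no way to finish within 1 s, because that is simply how much computation there is. To go faster, look for an algorithm with a complexity such as n^2.81 (Strassen, for example, is n^2.81; it is well known, easy to find online, and also covered in CLRS (Introduction to Algorithms)).
In short, optimizing the code might take you from 5 s to 3 s, but improving the time complexity takes you from 5 s to practically instant.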
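The operation counts behind this argument can be worked out directly. The sketch below counts one multiplication and one addition per inner step, as the post does, and compares n^3 with n^log2(7) while ignoring every constant factor of Strassen's algorithm, so it is an asymptotic illustration only. The function names are made up for this example.

```cpp
#include <cassert>
#include <cmath>

// Operations of the classical algorithm for two n x n matrices:
// n^3 inner steps, each counted as one multiplication plus one addition.
double classical_flops(double n)
{
    return 2.0 * n * n * n;
}

// Exponent of Strassen's algorithm: log2(7), about 2.807.
double strassen_exponent()
{
    return std::log2(7.0);
}

// Ratio n^3 / n^log2(7): the asymptotic gain with all constants ignored.
double asymptotic_ratio(double n)
{
    return std::pow(n, 3.0 - strassen_exponent());
}
```

For n = 1000 this gives 2 x 10^9 operations for the classical method and a ratio of roughly 3.8 between the two growth terms.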
dot99 2005-04-18
It's an algorithm problem...
blas 2005-04-18
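Your matrix dimensions are too small; I always test with sizes above 100 and below 1000!
On my 2.4 GHz machine, multiplying a 1000*1000 double matrix by another 1000*1000 double matrix takes about 1.6 seconds!
I have also measured the multiplication function of matcom4.5; it takes roughly 1 second!

This kind of computation needs to make full use of the L1 cache and has to be optimized differently for different CPUs,
because each CPU has a different L1 cache size, and each CPU also has its own extended instructions!

I don't believe Java can optimize for different CPUs!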
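To illustrate what "making use of the cache" usually means for matrix multiplication, here is a minimal sketch of loop blocking (tiling), with a plain triple loop next to it for comparison. It assumes square row-major matrices in flat vectors; the names matmul_blocked and matmul_naive and the block-size parameter bs are made up for this example, and the best bs depends on the CPU, as the post says.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Reference version: C += A * B with a plain triple loop.
void matmul_naive(const std::vector<double>& A, const std::vector<double>& B,
                  std::vector<double>& C, std::size_t n)
{
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t k = 0; k < n; ++k)
            for (std::size_t j = 0; j < n; ++j)
                C[i * n + j] += A[i * n + k] * B[k * n + j];
}

// Blocked version: C += A * B computed tile by tile, so the bs x bs tiles
// being combined can stay in cache while they are reused.
void matmul_blocked(const std::vector<double>& A, const std::vector<double>& B,
                    std::vector<double>& C, std::size_t n, std::size_t bs)
{
    assert(bs > 0 && A.size() == n * n && B.size() == n * n && C.size() == n * n);
    for (std::size_t ii = 0; ii < n; ii += bs)
        for (std::size_t kk = 0; kk < n; kk += bs)
            for (std::size_t jj = 0; jj < n; jj += bs) {
                // Clamp the tile bounds so sizes that are not a
                // multiple of bs are handled as well.
                const std::size_t i_end = std::min(ii + bs, n);
                const std::size_t k_end = std::min(kk + bs, n);
                const std::size_t j_end = std::min(jj + bs, n);
                for (std::size_t i = ii; i < i_end; ++i)
                    for (std::size_t k = kk; k < k_end; ++k) {
                        const double a = A[i * n + k];
                        for (std::size_t j = jj; j < j_end; ++j)
                            C[i * n + j] += a * B[k * n + j];
                    }
            }
}
```

Both functions accumulate the same products for every element of C; only the order in which memory is visited differs. Any speedup has to be measured on the target machine.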
Dream_soft 2005-04-18
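Thanks, blas. The approach in the program you posted is very close to the one used in Colt; it is indeed a very effective way to speed up matrix multiplication.

Disclaimer: I'm not a Javaer, although I use both languages. I won't claim Java is faster than C++, because my tests show that unoptimized matrix multiplication on the Sun JVM reaches only about 65% of the C++ performance. And that is with both sides allocating memory dynamically; if the C++ array sizes were fixed at compile time, C++ could be faster still. I used the VC++ 7.1 compiler. It would of course be faster with Intel's compiler, but Java has JRockit too, so I believe the comparison is fairly representative.

Back to the point: my goal is to see how efficient MTL is as a complete package. My project has to be integrated with other components, so using C++ is somewhat more trouble. That is why I want to see how much MTL can actually improve efficiency, so that I can strike a balance between run-time speed and development speed.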
llmsn 2005-04-17
This isn't yet another post from a Javaer, is it?
blas 2005-04-17
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;

C[inj] += alpha*sum00;
C[inj+1] += alpha*sum01;
C[inj+n] += alpha*sum10;
C[inj+n+1] += alpha*sum11;
}
}
}
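Across blas's two posts, the nine-line body is repeated 48 times (the last copy omits the final pB += n), which matches bsize = 48 and amounts to a hand-unrolled loop. As a sketch only, the same 2 x 2 register tile can be written compactly as below. The name kernel_2x2 is made up, indices are used instead of advancing the pointers, and the code assumes the same packed row-major layout as the post (row stride k for A, n for B and C). Whether a given compiler unrolls this as well as the handwritten version is not something the sketch claims.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical compact form of the unrolled body: accumulates a 2 x 2 tile
// of C from `len` steps along the k dimension.
//   pA0: first of two consecutive rows of A (row stride k)
//   pB : first of two adjacent columns of B (row stride n)
//   c  : top-left element of the 2 x 2 tile in C (row stride n)
void kernel_2x2(const double* pA0, std::size_t k,
                const double* pB, std::size_t n,
                double* c, double alpha, std::size_t len)
{
    assert(len <= k);
    const double* pA1 = pA0 + k;  // second row of A
    double sum00 = 0, sum01 = 0, sum10 = 0, sum11 = 0;
    for (std::size_t l = 0; l < len; ++l) {
        // Two values from A and two from B feed four running sums,
        // so each loaded value is used twice.
        const double a0 = pA0[l];
        const double a1 = pA1[l];
        const double b0 = pB[l * n];
        const double b1 = pB[l * n + 1];
        sum00 += a0 * b0;
        sum01 += a0 * b1;
        sum10 += a1 * b0;
        sum11 += a1 * b1;
    }
    c[0]     += alpha * sum00;
    c[1]     += alpha * sum01;
    c[n]     += alpha * sum10;
    c[n + 1] += alpha * sum11;
}
```

With len equal to the block size (48 in the posted routine), one call would correspond to one pass through the unrolled body.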
blas 2005-04-17
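Here is a matrix multiplication routine for you. Make sure you compile it with icc!
The program is too long, so it has to be split across posts!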

#define min(a,b) (((a)<(b)) ? (a):(b))

void dgemm_blk_unroll(const int m, const int n, const int k, const double alpha,
const double *A, const int lda, const double *B, const int ldb,
const double beta, double *C, const int ldc)
{
int i, j, l, jj, kk, ik, in, inj, minj, mink, mm2, nm2, kmb;
register double sum00, sum01, sum10, sum11, a0, a1, b0, b1;
const double *pA0, *pA1, *pB;

register int bsize = 48;

mm2 = m%2; // m mod by 2
nm2 = n%2; // n mod by 2
kmb = k%bsize; // k mod by bsize


for (i=0; i<m; ++i) {
for (j=0; j<n; ++j) {
C[i*n+j] *= beta;
}
}

if (mm2) {
for (j=0; j<n; ++j) {
sum00 = 0;
for (l=0; l<k; ++l) {
sum00 += A[l]*B[l*n+j];
}
C[j] += alpha*sum00;
}
}

if (nm2) {
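for (kk=0; kk<k; kk+=bsize) // kk walks the k dimension, so bound it by k (was n)
for (i=mm2; i<m; ++i) { // i walks the rows of A and C, so bound it by m (was n)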
ik = i*k;
sum00 = 0;
mink = min(kk+bsize, k);
for (l=kk; l<mink; ++l) {
sum00 += A[ik+l]*B[l*n];
}
C[i*n] += alpha*sum00;
}
}


for (jj=nm2; jj<n; jj+=bsize){
for (i=mm2; i<m; ++i) {
in = i*n;
ik = i*k;
minj = min(jj+bsize, n);
for (j=jj; j<minj; ++j) {
sum00 = 0;
for (l=0; l<kmb; ++l) {
sum00 += A[ik+l]*B[l*n+j];
}
C[in+j] += alpha*sum00;
}
}
}

for (jj=nm2; jj<n; jj+=bsize)
for (kk=kmb; kk<k; kk+=bsize)
for (i=mm2; i<m; i+=2) {
ik = i*k;
minj = (jj+bsize)<n ? (jj+bsize):n;
for (j=jj; j<minj; j+=2) {
sum00 = 0;
sum01 = 0;
sum10 = 0;
sum11 = 0;
inj = i*n + j;
pA0 = A + ik + kk;
pA1 = pA0 + k;
pB = B + kk*n + j;

a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;
a0 = *pA0++;
a1 = *pA1++;
b0 = pB[0];
b1 = pB[1];
sum00 += a0*b0;
sum01 += a0*b1;
sum10 += a1*b0;
sum11 += a1*b1;
pB += n;


smartduck 2005-04-16
Bump.
zengwujun 2005-04-16
up
yuanyou 2005-04-16
Bump.
pcboyxhy 2005-04-16
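The algorithm is not good.

The Java package is written with a far better algorithm than yours.
Even if the algorithm were the same,
its code design is far more advanced than this.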
ysbcg 2005-04-16
Heh, the STL is slow...
Switch to your own custom types.
Also, optimize the array accesses.
Dream_soft 2005-04-16
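To pcboyxhy:
This algorithm isn't something I wrote myself; it is MTL's algorithm, which is claimed to rival Fortran. That's why I keep thinking the problem is in my own program.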
