稀疏矩阵模板

Posted by Harid三月 - 29 - 2011 Leave comments

稀疏矩阵就是一个包含大量零元素的矩阵,具体零元素在矩阵中占多大的比例并没有明确的界定,所以“稀疏矩阵”只是一个直观上的概念,而非严格的定义。但是,稀疏矩阵的实际应用意义很大。例如,建立计算机网络时,用999条线路把1000个站点连接起来,用以表示这个网络的连接矩阵有1000×1000个矩阵元素,其中只有1998个非零元素,却有998002个零元素。显然,把所有的零元素都存在计算机中是不经济的,所以必须考虑对稀疏矩阵的压缩存储表示。

一种比较常用的表示稀疏矩阵的方法是用三元组表。每一个三元组为<row, column, value>,它能唯一确定一个矩阵元素。

下面是一个比较蛋疼的C++稀疏矩阵模板类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
#ifndef SPARSEMATRIX_H_
#define SPARSEMATRIX_H_
#include <iostream>
// Forward declarations: the class befriends the matching specialization of
// the stream operator, so both templates must be visible before the class body.
template <class Type> class SparseMatrix;
template <class Type> std::ostream & operator<< (std::ostream & out, const SparseMatrix<Type> & instance);

// Sparse matrix kept as a singly linked list of (row, col, value) triples.
// The object owns every node of the list; each node is allocated
// individually with `new Trituple`.
template <class Type>
class SparseMatrix{
private:
    int mRows;    // number of rows of the matrix
    int mCols;    // number of columns of the matrix
    int mItems;   // number of triples stored in the list (the non-zero entries)
    struct Trituple{
        int row;        // row index of this entry
        int col;        // column index of this entry
        Type value;     // value of this entry
        Trituple* next; // next triple in the list, 0 at the tail
        Trituple() : row(0), col(0), value(0), next(0){}
    } * smArray; // head of the linked list holding all triples
protected:
    void getMatrix(); // reads the triples from standard input
    Type find(int row, int col)const; // linear search of the list for one element (not optimized)
public:
    SparseMatrix();
    ~SparseMatrix();
    SparseMatrix(int rows, int cols);
    SparseMatrix(SparseMatrix &); // Copy constructor.

    // Copy assignment. Without it the compiler-generated operator would copy
    // the smArray pointer only, leaving two matrices owning (and later
    // freeing) the same nodes. This version deep-copies the source list and
    // releases the list previously held by *this.
    SparseMatrix<Type> & operator=(const SparseMatrix<Type> & source){
        if(this == &source)
            return *this;

        // Build the new list first; smArray is only consulted when the
        // source reports at least one item.
        Trituple* head = 0;
        Trituple* tail = 0;
        int copied = 0;
        if(source.mItems > 0){
            for(const Trituple* src = source.smArray;
                src != 0 && copied < source.mItems;
                src = src->next){
                Trituple* node = new Trituple;
                node->row = src->row;
                node->col = src->col;
                node->value = src->value;
                if(tail == 0)
                    head = node;
                else
                    tail->next = node;
                tail = node;
                ++copied;
            }
        }

        // Release the nodes currently owned, one `delete` per `new`.
        Trituple* old = (this->mItems > 0) ? this->smArray : 0;
        while(old != 0){
            Trituple* following = old->next;
            delete old;
            old = following;
        }

        this->mRows = source.mRows;
        this->mCols = source.mCols;
        this->mItems = copied;
        this->smArray = head;
        return *this;
    }

    SparseMatrix<Type> transpose(); // Matrix transpose.
    SparseMatrix<Type> add(SparseMatrix<Type> & b); //(*this) + b.
    SparseMatrix<Type> multiply(SparseMatrix<Type> & b); //(*this) * b.
    int getRows()const{
        return this->mRows; }
    int getCols()const{
        return this->mCols; }
    int getItems()const{
        return this->mItems; }
    Trituple * getArray()const{
        return this->smArray; }
    friend std::ostream & operator<< <Type>(std::ostream & out, const SparseMatrix<Type> & instance);
};
// End of class statement.
// Default constructor: an empty 0x0 matrix with no triples and a null list head.
template<class Type>
SparseMatrix<Type>::SparseMatrix()
    : mRows(0), mCols(0), mItems(0), smArray(0)
{
}
// Builds a rows x cols matrix and fills it interactively via getMatrix().
// The item count and list head are zeroed before reading so that the object
// is in a well-defined (empty) state even when no triple ends up being read;
// otherwise the destructor would operate on an indeterminate pointer.
template <class Type>
SparseMatrix<Type>::SparseMatrix(int rows, int cols)
    : mRows(rows), mCols(cols), mItems(0), smArray(0)
{
    this->getMatrix();
}
// Copy constructor: deep-copies the dimensions and every triple of `source`.
// The list head starts out null, so copying an empty matrix yields a valid
// empty matrix rather than one with an indeterminate smArray.
// mItems is set to the number of nodes actually copied.
template <class Type>
SparseMatrix<Type>::SparseMatrix(SparseMatrix & source){
    this->mRows = source.getRows();
    this->mCols = source.getCols();
    this->mItems = 0;
    this->smArray = 0;

    const int wanted = source.getItems();
    if(wanted <= 0)
        return;

    Trituple* tail = 0;
    // Single pass: allocate, fill and link each node as the source is walked.
    for(const Trituple* src = source.getArray();
        src != 0 && this->mItems < wanted;
        src = src->next){
        Trituple* node = new Trituple;
        node->row = src->row;
        node->col = src->col;
        node->value = src->value;
        if(tail == 0)
            this->smArray = node;
        else
            tail->next = node;
        tail = node;
        this->mItems++;
    }
}
// Returns the transpose of *this as a new matrix.
// The source list is scanned once per column k; every entry found in column k
// becomes an entry of row k in the result, so the result is grouped by its
// row index in increasing order.
// Nodes are created only for entries that are actually emitted, so the
// result's mItems always equals the number of nodes in its list (entries
// whose column lies outside [0, mCols) are not represented in the result).
template <class Type>
SparseMatrix<Type>  SparseMatrix<Type>:: transpose(){
    SparseMatrix<Type> b;           // starts empty: mItems == 0, smArray == 0
    b.mRows = this->getCols();
    b.mCols = this->getRows();

    if(this->getItems() > 0){
        Trituple* tail = 0;
        for(int k=0; k<this->getCols(); k++){
            const Trituple* src = this->getArray();
            for(int i=0; i<this->getItems(); i++){
                if(src->col == k){
                    Trituple* node = new Trituple;
                    node->row = k;
                    node->col = src->row;
                    node->value = src->value;
                    if(tail == 0)
                        b.smArray = node;
                    else
                        tail->next = node;
                    tail = node;
                    b.mItems++;
                }
                src = src->next;
            }
        }
    }
    return b;
}
// Returns (*this) + b as a new matrix.
// Both operands must have the same number of rows AND columns; on a mismatch
// a message is written to std::cerr and an empty 0x0 matrix is returned.
// The two lists are merged by comparing (row, col) of their current nodes,
// so the merge relies on both lists being ordered by row, then by column.
// When both operands hold an entry at the same position the values are
// summed, and the entry is dropped if the sum compares equal to 0.
template <class Type>
SparseMatrix<Type> SparseMatrix<Type>::add(SparseMatrix<Type> & b){
    if(this->getRows() != b.getRows() || this->getCols() != b.getCols()){
        std::cerr<<" Illegal addition! Quit now."<<std::endl;
        SparseMatrix nop;
        return nop;
    }
    SparseMatrix<Type> result;
    result.mCols = b.getCols();
    result.mRows = b.getRows();
    Trituple* tempA = this->getArray();
    Trituple* tempB = b.getArray();
    Trituple* last = 0;
    int countA = this->getItems();
    int countB = b.getItems();
    int countR = 0;
    while(countA != 0 || countB != 0){
        // Decide which operand(s) contribute the next entry.
        bool takeA;
        bool takeB;
        if(countB == 0){                        // only A has entries left
            takeA = true;
            takeB = false;
        }
        else if(countA == 0){                   // only B has entries left
            takeA = false;
            takeB = true;
        }
        else if(tempA->row != tempB->row){      // earlier row goes first
            takeA = tempA->row < tempB->row;
            takeB = !takeA;
        }
        else if(tempA->col != tempB->col){      // same row: earlier column first
            takeA = tempA->col < tempB->col;
            takeB = !takeA;
        }
        else{                                   // same position: sum both
            takeA = true;
            takeB = true;
        }

        const bool merged = takeA && takeB;
        const Trituple* src = takeA ? tempA : tempB;
        Type value = src->value;
        if(merged)
            value = tempA->value + tempB->value;

        // Entries taken from a single operand are copied as they are; only a
        // merged pair that cancels out is omitted from the result.
        if(!merged || value != 0){
            Trituple* node = new Trituple;
            node->row = src->row;
            node->col = src->col;
            node->value = value;
            if(last == 0)
                result.smArray = node;
            else
                last->next = node;
            last = node;
            countR++;
        }

        if(takeA){
            tempA = tempA->next;
            countA--;
        }
        if(takeB){
            tempB = tempB->next;
            countB--;
        }
    }
    result.mItems = countR;
    return result;
}
// Returns (*this) * b as a new matrix.
// Requires this->getCols() == b.getRows(); on a mismatch a message is written
// to std::cerr and an empty 0x0 matrix is returned.
// Every result cell (i, j) is computed as the sum over k of A(i,k) * B(k,j),
// with each operand element looked up through find(). Cells whose sum
// compares equal to 0 are not stored. Cells are visited row by row, column by
// column, so the result list is ordered by row, then by column.
template <class Type>
SparseMatrix<Type> SparseMatrix<Type>::multiply(SparseMatrix<Type> & b){
    if(this->getCols() != b.getRows()){
        std::cerr<<" Illegal multiplication! Quit now."<<std::endl;
        SparseMatrix nop;
        return nop;
    }
    SparseMatrix<Type> result;
    result.mRows = this->mRows;
    result.mCols = b.mCols;
    Trituple* last = 0;
    int count = 0;
    for(int i=0; i<this->getRows(); i++){
        for(int j=0; j<b.getCols(); j++){
            Type sum = 0;   // accumulator for cell (i, j), fresh for every cell
            for(int k=0; k<this->getCols(); k++){
                const Type left = this->find(i, k);
                // Skip the lookup in b when the left factor is zero.
                if(left != 0)
                    sum += left * b.find(k, j);
            }
            if(sum != 0){
                Trituple* node = new Trituple;
                node->row = i;
                node->col = j;
                node->value = sum;
                if(last == 0)
                    result.smArray = node;
                else
                    last->next = node;
                last = node;
                count++;
            }
        }
    }
    result.mItems = count;
    return result;
}
// Interactively reads the triples of the matrix from std::cin.
// First asks for the number of non-zero values, then reads that many
// "row column value" triples and appends each one to the list in input order.
// The list is reset to empty before reading. A failed or negative count is
// treated as 0, and reading stops at the first triple that cannot be parsed,
// so mItems always equals the number of nodes actually linked into the list.
template <class Type>
void SparseMatrix<Type>::getMatrix(){
    this->mItems = 0;
    this->smArray = 0;

    int requested = 0;
    std::cout <<"\n一共有多少个非零值?\n";
    if(!(std::cin >> requested) || requested < 0)
        requested = 0;
    std::cout <<"\n按如下方式每次输入一个非空三元组: \n\tRow\tColumn\tValue\n\t0  \t0  \t10\n\n";

    Trituple* last = 0;
    for(int i=0; i<requested; i++){
        std::cout <<"第"<< i+1<<"组: ";
        Trituple* node = new Trituple;
        if(!(std::cin >> node->row >> node->col >> node->value)){
            // Unreadable input: discard the unfilled node and stop.
            delete node;
            break;
        }
        if(last == 0)
            this->smArray = node;
        else
            last->next = node;
        last = node;
        this->mItems++;
    }
}
// Looks up the element at (row, col) by scanning the first mItems nodes of
// the triple list in order. Returns the value of the first matching node,
// or 0 when no node matches.
template <class Type>
Type SparseMatrix<Type>::find(int row, int col)const{
    const Trituple* node = this->getArray();
    for(int left = this->getItems(); left > 0; --left){
        const bool hit = (node->row == row) && (node->col == col);
        if(hit)
            return node->value;
        node = node->next;
    }
    return 0;
}
// Writes the matrix to `out` in dense form: a header line, then one line per
// row with every cell followed by a tab. Cells that compare equal to 0
// (including those with no stored triple) are printed as the literal " 0".
// Each cell is looked up with find(). Returns `out`.
template <class Type>
std::ostream & operator<< (std::ostream & out, const SparseMatrix<Type> & instance){
    out <<"The matrix is as follows:\n";
    for(int r=0; r<instance.mRows; ++r){
        for(int c=0; c<instance.mCols; ++c){
            Type cell = 0;
            cell = instance.find(r, c);
            if(cell != 0){
                out<<" "<<cell<<"\t";
                continue;
            }
            out<<" 0\t";
        }
        out << std::endl;
    }
    return out;
}
// Destructor: releases every node of the triple list.
// The nodes are allocated one at a time with `new Trituple`, so each must be
// released with a matching scalar `delete`; `delete []` on the head would be
// undefined behaviour and would not free the remaining nodes.
// The head pointer is only consulted when the matrix reports at least one
// item, so an empty matrix never touches smArray here.
template <class Type>
SparseMatrix<Type>::~SparseMatrix(){
    Trituple* node = (this->mItems > 0) ? this->smArray : 0;
    while(node != 0){
        Trituple* following = node->next;   // save the link before freeing
        delete node;
        node = following;
    }
    this->smArray = 0;
    this->mItems = 0;
}
#endif

如果您有更好的代码,敬请分享!

   声明:本文采用 BY-NC-SA 协议进行授权 | 星期九
   原创文章转载请注明:转自《稀疏矩阵模板》


分享按钮