ECOCPAK v0.9
|
00001 // Copyright (C) 2011 the authors listed below 00002 // http://ecocpak.sourceforge.net 00003 // 00004 // Authors: 00005 // - Dimitrios Bouzas (bouzas at ieee dot org) 00006 // - Nikolaos Arvanitopoulos (niarvani at ieee dot org) 00007 // - Anastasios Tefas (tefas at aiia dot csd dot auth dot gr) 00008 // 00009 // This file is part of the ECOC PAK C++ library. It is 00010 // provided without any warranty of fitness for any purpose. 00011 // 00012 // You can redistribute this file and/or modify it under 00013 // the terms of the GNU Lesser General Public License (LGPL) 00014 // as published by the Free Software Foundation, either 00015 // version 3 of the License or (at your option) any later 00016 // version. 00017 // (see http://www.opensource.org/licenses for more info) 00018 00019 00022 00023 00024 00046 void 00047 decode 00048 ( 00049 const mat& Xt, 00050 const uvec& lt, 00051 const imat& coding_matrix, 00052 const vector<Classifier*>& classifiers_vector, 00053 const vector<ClassData>& classes_vector, 00054 const int decoding_option, 00055 uvec& predictions, 00056 u32& n_missed, 00057 double& error, 00058 umat& confussion 00059 ) 00060 { 00061 // initialize 00062 n_missed = 0; 00063 error = 0.0; 00064 00065 // number of available classifiers 00066 const u32 n_classifiers = classifiers_vector.size(); 00067 00068 // number of samples 00069 const u32 n_samples = Xt.n_rows; 00070 00071 // vector of winning rows of coding matrix for each sample in Xt 00072 predictions = zeros<uvec>(n_samples); 00073 00074 // active bins of classes 00075 uvec bins = zeros<uvec>(coding_matrix.n_rows); 00076 00077 // find number of initial classes 00078 for(u32 i = 0; i < classes_vector.size(); i++) 00079 { 00080 if(bins[classes_vector[i].ClassLabel() -1] == 0) 00081 { 00082 bins[classes_vector[i].ClassLabel() -1] = 1; 00083 } 00084 00085 } 00086 00087 // active classes (i.e,. not subclasses) 00088 uvec active = find(bins == 1); 00089 00090 // number of classes 00091 const u32 n_classes = active.n_elem; 00092 00093 // allocate confussion matrix 00094 confussion = zeros<umat>(n_classes, n_classes); 00095 00096 // count missed samples according to user entered decoding strategy 00097 switch(decoding_option) 00098 { 00099 // Hamming, Euclidean and Laplacian decoding (generally 00100 // straightforward decoding strategies) 00101 case HAMMING: 00102 case EUCLIDEAN: 00103 case LAPLACIAN: 00104 case HAMMING_ATTENUATED: 00105 case EUCLIDEAN_ATTENUATED: 00106 case LINEAR_LOSS_BASED_DECODING: 00107 case EXPONENTIAL_LOSS_BASED_DECODING: 00108 case INVERSE_HAMMING_DECODING: 00109 case BETA_DENSITY_DECODING: 00110 { 00111 // iterate through testing samples 00112 for(u32 i = 0; i < n_samples; i++) 00113 { 00114 // current sample's codeword 00115 rowvec codeword = zeros<rowvec>(n_classifiers); 00116 00117 for(u32 j = 0; j < n_classifiers; j++) 00118 { 00119 codeword[j] = classifiers_vector[j]->predict(Xt.row(i)); 00120 } 00121 00122 // used to hold the winning label (i.e., the class which has the 00123 // closest codeword to current testing sample's codeword) 00124 predictions[i] = decode_codeword 00125 ( 00126 codeword, 00127 coding_matrix, 00128 decoding_option 00129 ); 00130 00131 // if predicted label is not equal with actual sample's 00132 // label increase number of misclassified samples 00133 if(classes_vector[predictions[i] - 1].ClassLabel() != lt[i]) 00134 { 00135 n_missed++; 00136 confussion(classes_vector[predictions[i] - 1].ClassLabel() - 1, lt[i] - 1)++; 00137 } 00138 else 00139 { 00140 confussion(lt[i] -1, lt[i] - 1)++; 00141 } 00142 00143 } 00144 00145 break; 00146 } 00147 00148 // Loss Weighted Decoding 00149 case LINEAR_LOSS_WEIGHTED_DECODING: 00150 { 00151 n_missed = linear_loss_weighted_decoding 00152 ( 00153 classifiers_vector, 00154 classes_vector, 00155 coding_matrix, 00156 Xt, 00157 lt, 00158 predictions, 00159 confussion 00160 ); 00161 00162 break; 00163 } 00164 00165 // Exponential Loss Weighted Decoding 00166 case EXPONENTIAL_LOSS_WEIGHTED_DECODING: 00167 { 00168 n_missed = exponential_loss_weighted_decoding 00169 ( 00170 classifiers_vector, 00171 classes_vector, 00172 coding_matrix, 00173 Xt, 00174 lt, 00175 predictions, 00176 confussion 00177 ); 00178 00179 break; 00180 } 00181 00182 // Probabilistic Based Decoding 00183 case PROBABILISTIC_BASED_DECODING: 00184 { 00185 n_missed = probabilistic_decoding 00186 ( 00187 classifiers_vector, 00188 classes_vector, 00189 coding_matrix, 00190 Xt, 00191 lt, 00192 predictions, 00193 confussion 00194 ); 00195 00196 break; 00197 } 00198 00199 // User Custom Decoding 00200 case CUSTOM_DECODING: 00201 { 00202 n_missed = custom_decoding 00203 ( 00204 classifiers_vector, 00205 classes_vector, 00206 coding_matrix, 00207 Xt, 00208 lt, 00209 predictions, 00210 confussion 00211 ); 00212 00213 break; 00214 } 00215 00216 default: 00217 { 00218 arma_debug_print("decode(): Unknown Decoding Option"); 00219 } 00220 00221 } 00222 00223 // compute total missclassification error 00224 error = double(n_missed) / double(n_samples); 00225 } 00226 00227 00228