ECOCPAK v0.9
|
00001 // Copyright (C) 2011 the authors listed below 00002 // http://ecocpak.sourceforge.net 00003 // 00004 // Authors: 00005 // - Dimitrios Bouzas (bouzas at ieee dot org) 00006 // - Nikolaos Arvanitopoulos (niarvani at ieee dot org) 00007 // - Anastasios Tefas (tefas at aiia dot csd dot auth dot gr) 00008 // 00009 // This file is part of the ECOC PAK C++ library. It is 00010 // provided without any warranty of fitness for any purpose. 00011 // 00012 // You can redistribute this file and/or modify it under 00013 // the terms of the GNU Lesser General Public License (LGPL) 00014 // as published by the Free Software Foundation, either 00015 // version 3 of the License or (at your option) any later 00016 // version. 00017 // (see http://www.opensource.org/licenses for more info) 00018 00019 00022 00023 00024 00025 #ifndef _CLASSIFIER_H_ 00026 #define _CLASSIFIER_H_ 00027 00028 00029 00030 #include <iostream> 00031 #include <vector> 00032 using namespace std; 00033 00034 00035 #include <armadillo> 00036 using namespace arma; 00037 00038 #include "ClassData.hpp" 00039 00040 00041 00044 enum 00045 { 00046 NCC, // - Nearest Mean Classifier. 00047 FLDA, // - Fisher Linear Discriminant Analysis Followed by 00048 // NMC. 00049 SVM, // - Support Vector Machine. 00050 ADABOOST, // - Discrete AdaBoost. 00051 LEAST_SQUARES, // - Sum of Least Error Squares Classifier. 00052 CUSTOM_CLASSIFIER // - Custom Classifier. 00053 }; 00054 00055 00056 00092 class Classifier 00093 { 00094 public: 00095 00096 // ================================================================ // 00097 // || Auxiliary || // 00098 // ================================================================ // 00099 00100 // print classifier info to specified output stream 00101 void print(ostream& out) const; 00102 00103 // ================================================================ // 00104 // || Overloaded Operators || // 00105 // ================================================================ // 00106 00107 // overloaded equallity operator 00108 bool operator==(Classifier& c); 00109 00110 // ================================================================ // 00111 // || Member Functions || // 00112 // ================================================================ // 00113 00114 // return prediction value of classifier for input feature vector 00115 virtual double predict(const rowvec& t) const = 0; 00116 00117 // ================================================================ // 00118 // || Attributes || // 00119 // ================================================================ // 00120 00121 // training error attained by binary classifier on specific problem 00122 double training_error; 00123 00124 // vector which holds pointers to ClassData objects that the 00125 // classifier considers them as possitive 00126 vector<ClassData*> pos; 00127 00128 // vector which holds pointers to ClassData objects that the 00129 // classifier considers them as negative 00130 vector<ClassData*> neg; 00131 00132 // number of possitive samples considered by binary classifier 00133 u32 n_pos; 00134 00135 // number of negative samples considered by binary classifier 00136 u32 n_neg; 00137 }; 00138 00139 00140 00141 // ================================================================== // 00142 // || Auxiliary || // 00143 // ================================================================== // 00144 00145 00157 void 00158 Classifier::print(ostream& out = cout) const 00159 { 00160 out << "--- Classifier info ---" << endl; 00161 00162 // classes marked with +1 00163 for(u32 i = 0; i < pos.size(); i++) 00164 { 00165 out << pos[i]->ClassIndex() << " "; 00166 } 00167 00168 out << " vs "; 00169 00170 // classes marked with -1 00171 for(u32 i = 0; i < neg.size(); i++) 00172 { 00173 out << neg[i]->ClassIndex() << " "; 00174 } 00175 00176 out << endl; 00177 00178 out << "Training Error: " << training_error << endl; 00179 } 00180 00181 00182 00183 // ================================================================== // 00184 // || Overloaded Operators || // 00185 // ================================================================== // 00186 00187 00188 00191 bool 00192 Classifier::operator==(Classifier& c) 00193 { 00194 // check wether vector of classes marked with +1 and vector of 00195 // classes marked with -1 are equal 00196 if(c.pos.size() != pos.size() || c.neg.size() != neg.size()) 00197 { 00198 return false; 00199 } 00200 00201 // check wether classifiers have the same classes marked with +1 00202 for(u32 i = 0; i < pos.size(); i++) 00203 { 00204 if(pos[i]->ClassIndex() != c.pos[i]->ClassIndex()) 00205 { 00206 return false; 00207 } 00208 00209 } 00210 00211 // check wether classifiers have the same classes marked with -1 00212 for(u32 i = 0; i < neg.size(); i++) 00213 { 00214 if(neg[i]->ClassIndex() != c.neg[i]->ClassIndex()) 00215 { 00216 return false; 00217 } 00218 00219 } 00220 00221 return true; 00222 } 00223 00224 00225 00226 #endif 00227 00228 00229