ECOCPAK v0.9
Classifier_svm.hpp
Go to the documentation of this file.
00001 // Copyright (C) 2011 the authors listed below
00002 // http://ecocpak.sourceforge.net
00003 //
00004 // Authors:
00005 // - Dimitrios Bouzas (bouzas at ieee dot org)
00006 // - Nikolaos Arvanitopoulos (niarvani at ieee dot org)
00007 // - Anastasios Tefas (tefas at aiia dot csd dot auth dot gr)
00008 //
00009 // This file is part of the ECOC PAK C++ library. It is
00010 // provided without any warranty of fitness for any purpose.
00011 //
00012 // You can redistribute this file and/or modify it under
00013 // the terms of the GNU Lesser General Public License (LGPL)
00014 // as published by the Free Software Foundation, either
00015 // version 3 of the License or (at your option) any later
00016 // version.
00017 // (see http://www.opensource.org/licenses for more info)
00018 
00019 
00022 
00023 
00024 #ifndef _CLASSIFIER_SVM_H_
00025 #define _CLASSIFIER_SVM_H_
00026 
00027 
00028 
00029 #include "Classifier.hpp"
00030 
00031 
00032 
00038 class Classifier_svm : public Classifier
00039   {
00040   public:
00041 
00042   // ---------------------------------------------------------------- //
00043   // ------------------------ Constructors -------------------------- //
00044   // ---------------------------------------------------------------- //
00045 
00046   // Copy ctor -- Overloaded
00047   Classifier_svm
00048     (
00049     const Classifier_svm& c
00050     );
00051 
00052   // User defined ctor -- Overloaded
00053   Classifier_svm
00054     (
00055     const mat& A,
00056     const mat& B
00057     );
00058 
00059   // dtor
00060   ~Classifier_svm();
00061 
00062   // ---------------------------------------------------------------- //
00063   // ---------------------- Member Functions ------------------------ //
00064   // ---------------------------------------------------------------- //
00065 
00066   // return prediction value of classifier for input feature vector
00067   double predict(const rowvec& t) const;
00068 
00069   // ---------------------------------------------------------------- //
00070   // -------------------------- Attributes -------------------------- //
00071   // ---------------------------------------------------------------- //
00072 
00073   // LIBSVM SVM model.
00074   struct svm_model model;
00075   };
00076 
00077 
00078 
00086 Classifier_svm::Classifier_svm
00087   (
00088   const Classifier_svm& c
00089   )
00090   {
00091   svm_model* tmp = const_cast<svm_model*>(&(c.model));
00092 
00093   // update SVM model
00094   model = modelcpy(tmp);
00095 
00096   // update the plus and minus classes
00097   pos = c.pos;
00098   neg = c.neg;
00099 
00100   //  number of possitive and negative samples
00101   n_pos = c.n_pos;
00102   n_neg = c.n_neg;
00103 
00104   // update training error
00105   training_error = c.training_error;
00106   }
00107 
00108 
00109 
00118 Classifier_svm::Classifier_svm
00119   (
00120   const mat& A,
00121   const mat& B
00122   )
00123   {
00124   // number of first class feature vectors
00125   const u32 n_samplesA = A.n_rows;
00126 
00127   // number of second class feature vectors
00128   const u32 n_samplesB = B.n_rows;
00129 
00130   // construct data matrix TODO::
00131   mat data_matrix = join_cols(A,B);
00132   colvec labels = join_cols
00133                      (
00134                      ones<colvec>(n_samplesA),
00135                      -ones<colvec>(n_samplesB)
00136                      );
00137 
00138   // LIBSVM problem -- set by read_training_file
00139   struct svm_problem problem;
00140   problem.x = dense_to_sparse(data_matrix);
00141   problem.l = n_samplesA + n_samplesB;
00142   problem.y = labels.memptr();
00143 
00144   // train LIBSVM model & copy
00145   svm_model* m = svm_train(&problem, &param);
00146   model = modelcpy(m);
00147 
00148   // clean up temporary created model
00149   delete_sparse_matrix(problem.x, problem.l);
00150   delete [] m->SV;
00151 
00152   for(u32 i = 0; i < m->nr_class - 1; i++)
00153     {
00154   delete [] m->sv_coef[i];
00155     }
00156 
00157   delete [] m->sv_coef;
00158   delete m->rho;
00159   delete [] m->label;
00160   delete m->probA;
00161   delete m->probB;
00162   delete m->nSV;
00163   delete m;
00164 
00165   // initialize number of possitive and negative samples
00166   n_pos = 0;
00167   n_neg = 0;
00168   }
00169 
00170 
00171 
00184 inline
00185 double
00186 Classifier_svm::predict(const rowvec& t) const
00187   {
00188   // temporary variable to hold classifier prediction
00189   double prediction = 0.0;
00190 
00191   // --- convert Armadillo row vector to LIBSVM type format --- //
00192 
00193   // indices of non zero elements
00194   ucolvec ind = find(t);
00195 
00196   // number of non zero elements
00197   const u32 n_elem = ind.n_elem;
00198 
00199   // allocate space for sparse row
00200   struct svm_node* x_space = new svm_node[n_elem + 1];
00201 
00202   // set last row element of sparse representation
00203   x_space[n_elem].index = -1;
00204   x_space[n_elem].value = 0;
00205 
00206   // fill in the rest of the elements
00207   for(u32 i = 0; i < n_elem; i++)
00208     {
00209     x_space[i].index = ind[i];
00210     x_space[i].value = t[ind[i]];
00211     }
00212 
00213   // compute prediction
00214   svm_predict_values(&model, x_space, &prediction);
00215 
00216   // delete temporary sparse representation
00217   delete [] x_space;
00218 
00219   // return prediction
00220   return prediction;
00221   }
00222 
00223 
00224 
00227 Classifier_svm::~Classifier_svm()
00228   {
00229   delete_sparse_matrix(model.SV, model.l);
00230 
00231   for(u32 i = 0; i < model.nr_class - 1; i++)
00232     {
00233   delete [] model.sv_coef[i];
00234     }
00235 
00236   delete [] model.SV;
00237   delete [] model.sv_coef;
00238   delete model.rho;
00239   delete model.label;
00240   delete model.probA;
00241   delete model.probB;
00242   delete model.nSV;
00243   delete &model;
00244   }
00245 
00246 
00247 
00248 #endif
00249 
00250 
00251 
 All Data Structures Namespaces Files Functions Variables Typedefs Enumerator Defines