24 #ifndef __ADA_BOOST_MODEL_H__
25 #define __ADA_BOOST_MODEL_H__
27 #include "algorithms/algorithm.h"
28 #include "data_management/data/homogen_numeric_table.h"
29 #include "algorithms/boosting/boosting_model.h"
30 #include "algorithms/classifier/classifier_model.h"
31 #include "algorithms/classifier/classifier_training_batch.h"
32 #include "algorithms/classifier/classifier_predict.h"
59 struct DAAL_EXPORT Parameter :
public boosting::Parameter
62 DAAL_DEPRECATED Parameter();
71 DAAL_DEPRECATED Parameter(services::SharedPtr<weak_learner::training::Batch> wlTrainForParameter,
72 services::SharedPtr<weak_learner::prediction::Batch> wlPredictForParameter,
73 double acc = 0.0,
size_t maxIter = 10);
75 double accuracyThreshold;
78 DAAL_DEPRECATED services::Status check() const DAAL_C11_OVERRIDE;
// AdaBoost model (interface1, deprecated): derives from boosting::Model and
// additionally owns the `_alpha` numeric table exposed through getAlpha().
// NOTE(review): this listing omits braces and access specifiers, so which
// members are public vs. protected cannot be determined from here.
90 class DAAL_EXPORT Model : public boosting::Model
// NOTE(review): the macro names classifier::Model while the class derives from
// boosting::Model — presumably boosting::Model itself derives from
// classifier::Model; confirm this is intended.
93 DECLARE_MODEL(Model, classifier::Model)
// Constructs a model for `nFeatures` features (deprecated). The `dummy`
// argument presumably only selects the floating-point type modelFPType.
102 template <
typename modelFPType>
103 DAAL_EXPORT DAAL_DEPRECATED Model(
size_t nFeatures, modelFPType dummy);
// Default constructor (deprecated): default-constructs the boosting::Model base
// and value-initializes `_alpha` (presumably an empty pointer).
109 DAAL_DEPRECATED Model() : boosting::Model(), _alpha() {}
// Factory returning the model in a SharedPtr (deprecated). `stat`, when
// non-NULL, presumably receives the construction status.
117 template<
typename modelFPType>
118 DAAL_EXPORT DAAL_DEPRECATED
static services::SharedPtr<Model> create(
size_t nFeatures, services::Status *stat = NULL);
// Returns the `_alpha` table (deprecated). Presumably the per-weak-learner
// boosting coefficients — TODO confirm.
128 DAAL_DEPRECATED data_management::NumericTablePtr getAlpha()
const;
// Table returned by getAlpha().
131 data_management::NumericTablePtr _alpha;
// Serialization hook: first delegates to boosting::Model::serialImpl, then
// passes `_alpha` to the archive.
// NOTE(review): the body is incomplete in this listing (no braces and no
// return statement visible); any status check between the base call and the
// `_alpha` write is not shown.
133 template<
typename Archive,
bool onDeserialize>
134 services::Status serialImpl(Archive *arch)
136 services::Status st = boosting::Model::serialImpl<Archive, onDeserialize>(arch);
139 arch->setSharedPtrObj(_alpha);
// Status-reporting constructor variant; presumably used by create().
144 template <
typename modelFPType>
145 DAAL_EXPORT Model(
size_t nFeatures, modelFPType dummy, services::Status &st);
// Shared-pointer alias for the interface1 model.
148 typedef services::SharedPtr<Model> ModelPtr;
// Identifiers of optional results for AdaBoost training; presumably combined
// into interface2::Parameter::resultsToCompute.
// NOTE(review): only computeWeakLearnersErrors is visible in this listing. Its
// value 0x1 and its use as the default of the DAAL_UINT64 `resToCompute`
// constructor argument suggest a bit-flag set — confirm against the training
// implementation.
156 enum ResultToComputeId
158 computeWeakLearnersErrors = 0x00000001ULL,
176 struct DAAL_EXPORT Parameter :
public classifier::Parameter
182 Parameter(
size_t nClasses = 2);
194 Parameter(services::SharedPtr<classifier::training::Batch> wlTrainForParameter,
195 services::SharedPtr<classifier::prediction::Batch> wlPredictForParameter,
196 double acc = 0.0,
size_t maxIter = 10,
double learnRate = 1.0, DAAL_UINT64 resToCompute = computeWeakLearnersErrors,
size_t nCl = 2);
198 services::SharedPtr<classifier::training::Batch> weakLearnerTraining;
199 services::SharedPtr<classifier::prediction::Batch> weakLearnerPrediction;
200 double accuracyThreshold;
201 size_t maxIterations;
203 DAAL_UINT64 resultsToCompute;
204 services::Status check() const DAAL_C11_OVERRIDE;
// AdaBoost model (interface2): derives from classifier::Model and stores the
// trained weak learner models in `_models`, the `_alpha` table, and the
// feature count `_nFeatures`.
// NOTE(review): this listing omits braces, access specifiers and the
// declaration of `_nFeatures`.
216 class DAAL_EXPORT Model : public classifier::Model
219 DECLARE_MODEL(Model, classifier::Model)
// Constructs a model for `nFeatures` features. The `dummy` argument presumably
// only selects the floating-point type modelFPType.
228 template <
typename modelFPType>
229 DAAL_EXPORT Model(
size_t nFeatures, modelFPType dummy);
// Default/size constructor: allocates an empty DataCollection for the weak
// learner models, records `nFeatures`, and value-initializes `_alpha`.
// NOTE(review): not marked explicit, so a size_t converts implicitly to Model.
// NOTE(review): confirm this mem-initializer order matches the declaration
// order of `_models`, `_nFeatures` and `_alpha` (otherwise -Wreorder fires); the
// `_nFeatures` declaration is not visible here.
235 Model(
size_t nFeatures = 0) : _models(new data_management::DataCollection()), _nFeatures(nFeatures), _alpha() {}
// Factory returning the model in a SharedPtr. `stat`, when non-NULL,
// presumably receives the construction status.
243 template<
typename modelFPType>
244 DAAL_EXPORT
static services::SharedPtr<Model> create(
size_t nFeatures, services::Status *stat = NULL);
// Returns the number of stored weak learner models.
252 size_t getNumberOfWeakLearners()
const;
// Returns the weak learner model at position `idx`. Behavior for an
// out-of-range `idx` is not visible here.
259 classifier::ModelPtr getWeakLearnerModel(
size_t idx)
const;
// Appends a weak learner model.
265 void addWeakLearnerModel(classifier::ModelPtr model);
// Removes all stored weak learner models.
270 void clearWeakLearnerModels();
// Returns the feature count recorded at construction (`_nFeatures`).
276 size_t getNumberOfFeatures() const DAAL_C11_OVERRIDE {
return _nFeatures; }
// Returns the `_alpha` table. Presumably the per-weak-learner boosting
// coefficients — TODO confirm.
284 data_management::NumericTablePtr getAlpha()
const;
// Collection of weak learner models.
288 data_management::DataCollectionPtr _models;
// Table returned by getAlpha().
289 data_management::NumericTablePtr _alpha;
// Serialization hook: checks the status of classifier::Model::serialImpl via
// DAAL_CHECK_STATUS, then passes `_nFeatures`, `_models` and `_alpha` to the
// archive, in that order.
// NOTE(review): presumably the archive is sequential, so this order is part of
// the serialized format — do not reorder. The declaration of `st` and the
// return statement are not visible in this listing.
291 template<
typename Archive,
bool onDeserialize>
292 services::Status serialImpl(Archive *arch)
295 DAAL_CHECK_STATUS(st, (classifier::Model::serialImpl<Archive, onDeserialize>(arch)));
296 arch->set(_nFeatures);
297 arch->setSharedPtrObj(_models);
298 arch->setSharedPtrObj(_alpha);
// Status-reporting constructor variant; presumably used by create().
303 template <
typename modelFPType>
304 DAAL_EXPORT Model(
size_t nFeatures, modelFPType dummy, services::Status &st);
// Shared-pointer alias for the interface2 model.
307 typedef services::SharedPtr<Model> ModelPtr;
310 using interface2::Parameter;
311 using interface2::Model;
312 using interface2::ModelPtr;
daal::algorithms::adaboost::interface2::Parameter::weakLearnerTraining
services::SharedPtr< classifier::training::Batch > weakLearnerTraining
Definition: adaboost_model.h:198
daal::algorithms::adaboost::interface2::Model::getNumberOfFeatures
size_t getNumberOfFeatures() const DAAL_C11_OVERRIDE
Definition: adaboost_model.h:276
daal::algorithms::adaboost::interface1::Model::Model
DAAL_DEPRECATED Model()
Definition: adaboost_model.h:109
daal::algorithms::adaboost::interface2::Parameter::accuracyThreshold
double accuracyThreshold
Definition: adaboost_model.h:200
daal::algorithms::adaboost::interface2::Parameter::maxIterations
size_t maxIterations
Definition: adaboost_model.h:201
daal::algorithms::adaboost::interface2::Parameter::weakLearnerPrediction
services::SharedPtr< classifier::prediction::Batch > weakLearnerPrediction
Definition: adaboost_model.h:199
daal::algorithms::adaboost::interface1::Parameter::accuracyThreshold
double accuracyThreshold
Definition: adaboost_model.h:75
daal::algorithms::adaboost::interface2::Parameter::resultsToCompute
DAAL_UINT64 resultsToCompute
Definition: adaboost_model.h:203
daal::algorithms::elastic_net::training::model
Definition: elastic_net_training_types.h:109
daal::algorithms::adaboost::interface1::Parameter
AdaBoost algorithm parameters.
Definition: adaboost_model.h:59
daal::algorithms::adaboost::interface1::Model
Model of the classifier trained by the adaboost::training::Batch algorithm.
Definition: adaboost_model.h:90
daal::algorithms::adaboost::ResultToComputeId
ResultToComputeId
Definition: adaboost_model.h:156
daal::algorithms::adaboost::interface2::Model::Model
Model(size_t nFeatures=0)
Definition: adaboost_model.h:235
daal::algorithms::adaboost::interface2::Parameter::learningRate
double learningRate
Definition: adaboost_model.h:202
daal::algorithms::adaboost::interface1::Parameter::maxIterations
size_t maxIterations
Definition: adaboost_model.h:76