24 #ifndef __GBT_CLASSIFICATION_TRAINING_BATCH_H__
25 #define __GBT_CLASSIFICATION_TRAINING_BATCH_H__
27 #include "algorithms/classifier/classifier_training_batch.h"
28 #include "algorithms/gradient_boosted_trees/gbt_classification_training_types.h"
36 namespace classification
// interface1::BatchContainer (deprecated): container that runs the Gradient
// Boosted Trees classification training implementation for the deprecated
// interface1::Batch algorithm (which instantiates it via
// __DAAL_ALGORITHM_CONTAINER).
// Template parameters:
//   algorithmFPType - floating-point type used by the computation
//   method          - training method
//   cpu             - CPU type the container is instantiated for
// NOTE(review): this listing omits braces, access specifiers and the original
// doc comments; only the member declarations are visible here.
55 template<
typename algorithmFPType, Method method, CpuType cpu>
56 class BatchContainer :
public TrainingContainerIface<batch>
// Constructor; takes a pointer to the library environment descriptor.
64 DAAL_DEPRECATED BatchContainer(daal::services::Environment::env *daalEnv);
66 DAAL_DEPRECATED ~BatchContainer();
// Runs the training computation (overrides TrainingContainerIface::compute).
71 DAAL_DEPRECATED services::Status compute() DAAL_C11_OVERRIDE;
// Presumably prepares internal state before compute() - definition not shown, confirm.
72 DAAL_DEPRECATED services::Status setupCompute() DAAL_C11_OVERRIDE;
// interface1::Batch (deprecated): trains a Gradient Boosted Trees
// classification model in the batch processing mode.
// Template parameters:
//   algorithmFPType - floating-point type (defaults to DAAL_ALGORITHM_FP_TYPE)
//   method          - training method (defaults to defaultDense)
92 template<
typename algorithmFPType = DAAL_ALGORITHM_FP_TYPE, Method method = defaultDense>
93 class DAAL_EXPORT Batch :
public classifier::training::interface1::Batch
96 typedef classifier::training::interface1::Batch super;
// Input, parameter and result types used by this algorithm.
98 typedef typename super::InputType InputType;
99 typedef algorithms::gbt::classification::training::interface1::Parameter ParameterType;
100 typedef algorithms::gbt::classification::training::Result ResultType;
// Constructor; nClasses is the number of classes of the classification problem.
108 DAAL_DEPRECATED Batch(
size_t nClasses);
// Copy constructor. NOTE(review): which members are copied (input/parameters
// vs. results) is not visible in this listing - confirm in the implementation.
116 DAAL_DEPRECATED Batch(
const Batch<algorithmFPType, method> &other);
// Destructor (its body is not shown in this listing).
119 DAAL_DEPRECATED ~Batch()
// Mutable access to the algorithm parameters: downcasts the stored _par
// pointer to this algorithm's ParameterType and returns a reference to it.
128 DAAL_DEPRECATED ParameterType& parameter()
{
    ParameterType *const typedPar = static_cast<ParameterType *>(_par);
    return *typedPar;
}
// Read-only access to the algorithm parameters: downcasts the stored _par
// pointer to const ParameterType and returns a const reference to it.
134 DAAL_DEPRECATED
const ParameterType& parameter() const
{
    const ParameterType *const typedPar = static_cast<const ParameterType *>(_par);
    return *typedPar;
}
// Returns a pointer to this algorithm's input member.
140 DAAL_DEPRECATED InputType * getInput() DAAL_C11_OVERRIDE
{
    InputType *const inputPtr = &input;
    return inputPtr;
}
// Returns the computation method of the algorithm (the `method` template
// argument) as an int. Uses a named cast instead of a C-style cast.
146 DAAL_DEPRECATED_VIRTUAL
virtual int getMethod() const DAAL_C11_OVERRIDE
{
    return static_cast<int>(method);
}
// Returns the stored training result, converted to this algorithm's result
// pointer type via ResultType::cast(_result).
152 DAAL_DEPRECATED interface1::ResultPtr getResult()
154 return ResultType::cast(_result);
// Replaces the stored result with a newly constructed ResultType.
// DAAL_CHECK presumably returns a Status carrying ErrorNullResult when
// _result is empty after the reset (confirm macro semantics); otherwise an
// empty (success) Status is returned.
160 DAAL_DEPRECATED services::Status resetResult() DAAL_C11_OVERRIDE
162 _result.reset(
new ResultType());
163 DAAL_CHECK(_result, services::ErrorNullResult);
165 return services::Status();
// Returns a copy of this algorithm owned by a SharedPtr; the copy itself is
// produced by cloneImpl().
173 DAAL_DEPRECATED services::SharedPtr<Batch<algorithmFPType, method> > clone()
const
175 return services::SharedPtr<Batch<algorithmFPType, method> >(cloneImpl());
// Presumably validates the compute parameters before running; the definition
// is not shown in this listing - confirm.
178 DAAL_DEPRECATED_VIRTUAL
virtual services::Status checkComputeParams() DAAL_C11_OVERRIDE;
// Heap-allocates a copy of this algorithm with the copy constructor; the
// returned raw pointer is owned by the caller (clone() wraps it in a SharedPtr).
181 virtual Batch<algorithmFPType, method> * cloneImpl() const DAAL_C11_OVERRIDE
183 return new Batch<algorithmFPType, method>(*this);
// Allocates the result object for the current input and parameters.
// Fetches the typed result via getResult(), checks it with DAAL_CHECK
// (ErrorNullResult), asks it to allocate storage for algorithmFPType given
// the input, the address of the parameter object and the method, and then
// caches the raw result pointer in _res.
186 services::Status allocateResult() DAAL_C11_OVERRIDE
188 ResultPtr res = getResult();
189 DAAL_CHECK(res, services::ErrorNullResult);
190 services::Status s = res->template allocate<algorithmFPType>(&input, &parameter(), method);
191 _res = _result.get();
// Initialization tail of interface1::Batch (the enclosing function's
// signature is not visible in this listing): creates the deprecated
// interface1::BatchContainer for this FP type and method, bound to _env,
// and installs a freshly constructed ResultType as the stored result.
197 _ac =
new __DAAL_ALGORITHM_CONTAINER(batch, interface1::BatchContainer, algorithmFPType, method)(&_env);
199 _result.reset(
new ResultType());
// interface2::BatchContainer: container that runs the Gradient Boosted Trees
// classification training implementation for interface2::Batch.
// Template parameters:
//   algorithmFPType - floating-point type used by the computation
//   method          - training method
//   cpu             - CPU type the container is instantiated for
// NOTE(review): the destructor (~BatchContainer) is listed in the
// cross-reference but is not shown in this listing.
215 template<
typename algorithmFPType, Method method, CpuType cpu>
216 class BatchContainer :
public TrainingContainerIface<batch>
// Constructor; takes a pointer to the library environment descriptor.
224 BatchContainer(daal::services::Environment::env *daalEnv);
// Runs the training computation (overrides TrainingContainerIface::compute).
231 services::Status compute() DAAL_C11_OVERRIDE;
// Presumably prepares internal state before compute() - definition not shown, confirm.
232 services::Status setupCompute() DAAL_C11_OVERRIDE;
// interface2::Batch: trains a Gradient Boosted Trees classification model in
// the batch processing mode.
// Template parameters:
//   algorithmFPType - floating-point type (defaults to DAAL_ALGORITHM_FP_TYPE)
//   method          - training method (defaults to defaultDense)
// NOTE(review): the destructor (~Batch, line 278 per the cross-reference) is
// not shown in this listing.
251 template<
typename algorithmFPType = DAAL_ALGORITHM_FP_TYPE, Method method = defaultDense>
252 class DAAL_EXPORT Batch :
public classifier::training::Batch
255 typedef classifier::training::Batch super;
// Input, parameter and result types used by this algorithm.
257 typedef typename super::InputType InputType;
258 typedef algorithms::gbt::classification::training::Parameter ParameterType;
259 typedef algorithms::gbt::classification::training::Result ResultType;
// Constructor; nClasses is the number of classes of the classification problem.
267 Batch(
size_t nClasses);
// Copy constructor. NOTE(review): which members are copied (input/parameters
// vs. results) is not visible in this listing - confirm in the implementation.
275 Batch(
const Batch<algorithmFPType, method> &other);
// Mutable access to the algorithm parameters: downcasts the stored _par
// pointer to this algorithm's ParameterType and returns a reference to it.
287 ParameterType& parameter()
{
    ParameterType *const typedPar = static_cast<ParameterType *>(_par);
    return *typedPar;
}
// Read-only access to the algorithm parameters: downcasts the stored _par
// pointer to const ParameterType and returns a const reference to it.
293 const ParameterType& parameter() const
{
    const ParameterType *const typedPar = static_cast<const ParameterType *>(_par);
    return *typedPar;
}
// Returns a pointer to this algorithm's input member.
299 InputType * getInput() DAAL_C11_OVERRIDE
{
    InputType *const inputPtr = &input;
    return inputPtr;
}
// Returns the computation method of the algorithm (the `method` template
// argument) as an int. Uses a named cast instead of a C-style cast.
305 virtual int getMethod() const DAAL_C11_OVERRIDE
{
    return static_cast<int>(method);
}
// Returns the stored training result, converted to this algorithm's result
// pointer type via ResultType::cast(_result).
311 ResultPtr getResult()
313 return ResultType::cast(_result);
// Replaces the stored result with a newly constructed ResultType.
// DAAL_CHECK presumably returns a Status carrying ErrorNullResult when
// _result is empty after the reset (confirm macro semantics); otherwise an
// empty (success) Status is returned.
319 services::Status resetResult() DAAL_C11_OVERRIDE
321 _result.reset(
new ResultType());
322 DAAL_CHECK(_result, services::ErrorNullResult);
324 return services::Status();
// Returns a copy of this algorithm owned by a SharedPtr; the copy itself is
// produced by cloneImpl().
332 services::SharedPtr<Batch<algorithmFPType, method> > clone()
const
334 return services::SharedPtr<Batch<algorithmFPType, method> >(cloneImpl());
// Presumably validates the compute parameters before running; the definition
// is not shown in this listing - confirm.
337 virtual services::Status checkComputeParams() DAAL_C11_OVERRIDE;
// Heap-allocates a copy of this algorithm with the copy constructor; the
// returned raw pointer is owned by the caller (clone() wraps it in a SharedPtr).
340 virtual Batch<algorithmFPType, method> * cloneImpl() const DAAL_C11_OVERRIDE
342 return new Batch<algorithmFPType, method>(*this);
// Allocates the result object for the current input and parameters.
// Fetches the typed result via getResult(), checks it with DAAL_CHECK
// (ErrorNullResult), asks it to allocate storage for algorithmFPType given
// the input, the address of the parameter object and the method, and then
// caches the raw result pointer in _res.
345 services::Status allocateResult() DAAL_C11_OVERRIDE
347 ResultPtr res = getResult();
348 DAAL_CHECK(res, services::ErrorNullResult);
349 services::Status s = res->template allocate<algorithmFPType>(&input, &parameter(), method);
350 _res = _result.get();
// Initialization tail of interface2::Batch (the enclosing function's
// signature is not visible in this listing): creates the interface2
// BatchContainer for this FP type and method, bound to _env, and installs a
// freshly constructed ResultType as the stored result.
356 _ac =
new __DAAL_ALGORITHM_CONTAINER(batch, BatchContainer, algorithmFPType, method)(&_env);
358 _result.reset(
new ResultType());
// Expose the interface2 versions under the unversioned names
// daal::algorithms::gbt::classification::training::{BatchContainer, Batch}.
363 using interface2::BatchContainer;
364 using interface2::Batch;
371 #endif // __GBT_CLASSIFICATION_TRAINING_BATCH_H__
daal::algorithms::gbt::classification::training::interface1::Batch
Trains a model of the Gradient Boosted Trees algorithm in the batch processing mode.
Definition: gbt_classification_training_batch.h:93
daal::algorithms::gbt::classification::training::interface1::Batch::parameter
DAAL_DEPRECATED ParameterType & parameter()
Definition: gbt_classification_training_batch.h:128
daal::algorithms::gbt::classification::training::interface2::BatchContainer::compute
services::Status compute() DAAL_C11_OVERRIDE
daal::algorithms::classifier::training::interface1::Result
Provides methods to access final results obtained with the compute() method in the batch processing mode.
Definition: classifier_training_types.h:198
daal::algorithms::gbt::classification::training::interface2::Batch::getMethod
virtual int getMethod() const DAAL_C11_OVERRIDE
Definition: gbt_classification_training_batch.h:305
daal::batch
Definition: daal_defines.h:112
daal::algorithms::gbt::classification::training::interface1::BatchContainer::BatchContainer
DAAL_DEPRECATED BatchContainer(daal::services::Environment::env *daalEnv)
daal::algorithms::gbt::classification::training::interface2::BatchContainer::BatchContainer
BatchContainer(daal::services::Environment::env *daalEnv)
daal::algorithms::gbt::classification::training::interface2::Batch::clone
services::SharedPtr< Batch< algorithmFPType, method > > clone() const
Definition: gbt_classification_training_batch.h:332
daal::algorithms::classifier::training::interface1::Batch
Algorithm class for training the classifier model.
Definition: classifier_training_batch.h:58
daal::algorithms::gbt::classification::training::interface1::Batch::parameter
DAAL_DEPRECATED const ParameterType & parameter() const
Definition: gbt_classification_training_batch.h:134
daal::algorithms::gbt::classification::training::interface1::BatchContainer::~BatchContainer
DAAL_DEPRECATED ~BatchContainer()
daal::algorithms::gbt::classification::training::interface2::Batch::getResult
ResultPtr getResult()
Definition: gbt_classification_training_batch.h:311
daal::algorithms::gbt::classification::training::interface1::Batch::getInput
DAAL_DEPRECATED InputType * getInput() DAAL_C11_OVERRIDE
Definition: gbt_classification_training_batch.h:140
daal::algorithms::gbt::classification::training::interface1::Batch::getMethod
virtual DAAL_DEPRECATED_VIRTUAL int getMethod() const DAAL_C11_OVERRIDE
Definition: gbt_classification_training_batch.h:146
daal::algorithms::gbt::classification::training::interface1::BatchContainer::compute
DAAL_DEPRECATED services::Status compute() DAAL_C11_OVERRIDE
daal::services::ErrorNullResult
Definition: error_indexes.h:98
daal::algorithms::gbt::classification::training::interface2::Batch
Trains a model of the Gradient Boosted Trees algorithm in the batch processing mode.
Definition: gbt_classification_training_batch.h:252
daal::algorithms::gbt::classification::training::interface1::Batch::~Batch
DAAL_DEPRECATED ~Batch()
Definition: gbt_classification_training_batch.h:119
daal::algorithms::gbt::classification::training::interface1::Batch::clone
DAAL_DEPRECATED services::SharedPtr< Batch< algorithmFPType, method > > clone() const
Definition: gbt_classification_training_batch.h:173
daal::algorithms::gbt::classification::training::interface1::Batch::getResult
DAAL_DEPRECATED interface1::ResultPtr getResult()
Definition: gbt_classification_training_batch.h:152
daal::algorithms::gbt::classification::training::interface2::BatchContainer::~BatchContainer
~BatchContainer()
daal::algorithms::gbt::classification::training::interface2::Batch::~Batch
~Batch()
Definition: gbt_classification_training_batch.h:278
daal::algorithms::classifier::interface1::Parameter
Base class for the parameters of the classification algorithm.
Definition: classifier_model.h:69
daal::algorithms::gbt::classification::training::interface2::Batch::parameter
const ParameterType & parameter() const
Definition: gbt_classification_training_batch.h:293
daal::algorithms::classifier::training::interface1::Input
Base class for the input objects in the training stage of the classification algorithms.
Definition: classifier_training_types.h:110
daal::algorithms::gbt::classification::training::interface2::BatchContainer
Provides methods to run implementations of Gradient Boosted Trees model-based training. This class is associated with the daal::algorithms::gbt::classification::training::Batch class.
Definition: gbt_classification_training_batch.h:216
daal::algorithms::gbt::classification::training::interface1::Parameter
Gradient Boosted Trees algorithm parameters.
Definition: gbt_classification_training_types.h:95
daal::algorithms::gbt::classification::training::interface2::Batch::getInput
InputType * getInput() DAAL_C11_OVERRIDE
Definition: gbt_classification_training_batch.h:299
daal::algorithms::gbt::classification::training::interface1::BatchContainer
Provides methods to run implementations of Gradient Boosted Trees model-based training.
Definition: gbt_classification_training_batch.h:56
daal::algorithms::gbt::classification::training::interface2::Batch::parameter
ParameterType & parameter()
Definition: gbt_classification_training_batch.h:287
daal::algorithms::gbt::classification::training::interface1::Batch::input
InputType input
Definition: gbt_classification_training_batch.h:102
daal::algorithms::gbt::classification::training::interface1::Batch::resetResult
DAAL_DEPRECATED services::Status resetResult() DAAL_C11_OVERRIDE
Definition: gbt_classification_training_batch.h:160
daal::algorithms::gbt::classification::training::interface2::Batch::input
InputType input
Definition: gbt_classification_training_batch.h:261
daal::algorithms::TrainingContainerIface
Abstract interface class that provides virtual methods to access and run implementations of the model...
Definition: training.h:52
daal::algorithms::gbt::classification::training::interface2::Batch::resetResult
services::Status resetResult() DAAL_C11_OVERRIDE
Definition: gbt_classification_training_batch.h:319