C++ API Reference for Intel® Data Analytics Acceleration Library 2020 Update 1

gbt_classification_training_batch.h
1 /* file: gbt_classification_training_batch.h */
2 /*******************************************************************************
3 * Copyright 2014-2020 Intel Corporation
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17 
18 /*
19 //++
20 // Implementation of the interface for Gradient Boosted Trees model-based training
21 //--
22 */
23 
24 #ifndef __GBT_CLASSIFICATION_TRAINING_BATCH_H__
25 #define __GBT_CLASSIFICATION_TRAINING_BATCH_H__
26 
27 #include "algorithms/classifier/classifier_training_batch.h"
28 #include "algorithms/gradient_boosted_trees/gbt_classification_training_types.h"
29 
30 namespace daal
31 {
32 namespace algorithms
33 {
34 namespace gbt
35 {
36 namespace classification
37 {
38 namespace training
39 {
40 namespace interface1
41 {
55 template<typename algorithmFPType, Method method, CpuType cpu>
56 class BatchContainer : public TrainingContainerIface<batch>
57 {
58 public:
64  DAAL_DEPRECATED BatchContainer(daal::services::Environment::env *daalEnv);
66  DAAL_DEPRECATED ~BatchContainer();
71  DAAL_DEPRECATED services::Status compute() DAAL_C11_OVERRIDE;
72  DAAL_DEPRECATED services::Status setupCompute() DAAL_C11_OVERRIDE;
73 };
74 
92 template<typename algorithmFPType = DAAL_ALGORITHM_FP_TYPE, Method method = defaultDense>
93 class DAAL_EXPORT Batch : public classifier::training::interface1::Batch
94 {
95 public:
96  typedef classifier::training::interface1::Batch super;
97 
98  typedef typename super::InputType InputType;
99  typedef algorithms::gbt::classification::training::interface1::Parameter ParameterType;
100  typedef algorithms::gbt::classification::training::Result ResultType;
101 
102  InputType input;
108  DAAL_DEPRECATED Batch(size_t nClasses);
109 
116  DAAL_DEPRECATED Batch(const Batch<algorithmFPType, method> &other);
117 
119  DAAL_DEPRECATED ~Batch()
120  {
121  delete _par;
122  }
123 
128  DAAL_DEPRECATED ParameterType& parameter() { return *static_cast<ParameterType*>(_par); }
129 
134  DAAL_DEPRECATED const ParameterType& parameter() const { return *static_cast<const ParameterType*>(_par); }
135 
140  DAAL_DEPRECATED InputType * getInput() DAAL_C11_OVERRIDE { return &input; }
141 
146  DAAL_DEPRECATED_VIRTUAL virtual int getMethod() const DAAL_C11_OVERRIDE { return(int)method; }
147 
152  DAAL_DEPRECATED interface1::ResultPtr getResult()
153  {
154  return ResultType::cast(_result);
155  }
156 
160  DAAL_DEPRECATED services::Status resetResult() DAAL_C11_OVERRIDE
161  {
162  _result.reset(new ResultType());
163  DAAL_CHECK(_result, services::ErrorNullResult);
164  _res = NULL;
165  return services::Status();
166  }
167 
173  DAAL_DEPRECATED services::SharedPtr<Batch<algorithmFPType, method> > clone() const
174  {
175  return services::SharedPtr<Batch<algorithmFPType, method> >(cloneImpl());
176  }
177 
178  DAAL_DEPRECATED_VIRTUAL virtual services::Status checkComputeParams() DAAL_C11_OVERRIDE;
179 
180 protected:
181  virtual Batch<algorithmFPType, method> * cloneImpl() const DAAL_C11_OVERRIDE
182  {
183  return new Batch<algorithmFPType, method>(*this);
184  }
185 
186  services::Status allocateResult() DAAL_C11_OVERRIDE
187  {
188  ResultPtr res = getResult();
189  DAAL_CHECK(res, services::ErrorNullResult);
190  services::Status s = res->template allocate<algorithmFPType>(&input, &parameter(), method);
191  _res = _result.get();
192  return s;
193  }
194 
195  void initialize()
196  {
197  _ac = new __DAAL_ALGORITHM_CONTAINER(batch, interface1::BatchContainer, algorithmFPType, method)(&_env);
198  _in = &input;
199  _result.reset(new ResultType());
200  }
201 };
203 } // namespace interface1
204 
205 namespace interface2
206 {
215 template<typename algorithmFPType, Method method, CpuType cpu>
216 class BatchContainer : public TrainingContainerIface<batch>
217 {
218 public:
224  BatchContainer(daal::services::Environment::env *daalEnv);
226  ~BatchContainer();
231  services::Status compute() DAAL_C11_OVERRIDE;
232  services::Status setupCompute() DAAL_C11_OVERRIDE;
233 };
251 template<typename algorithmFPType = DAAL_ALGORITHM_FP_TYPE, Method method = defaultDense>
252 class DAAL_EXPORT Batch : public classifier::training::Batch
253 {
254 public:
255  typedef classifier::training::Batch super;
256 
257  typedef typename super::InputType InputType;
258  typedef algorithms::gbt::classification::training::Parameter ParameterType;
259  typedef algorithms::gbt::classification::training::Result ResultType;
260 
261  InputType input;
267  Batch(size_t nClasses);
268 
275  Batch(const Batch<algorithmFPType, method> &other);
276 
278  ~Batch()
279  {
280  delete _par;
281  }
282 
287  ParameterType& parameter() { return *static_cast<ParameterType*>(_par); }
288 
293  const ParameterType& parameter() const { return *static_cast<const ParameterType*>(_par); }
294 
299  InputType * getInput() DAAL_C11_OVERRIDE { return &input; }
300 
305  virtual int getMethod() const DAAL_C11_OVERRIDE { return(int)method; }
306 
311  ResultPtr getResult()
312  {
313  return ResultType::cast(_result);
314  }
315 
319  services::Status resetResult() DAAL_C11_OVERRIDE
320  {
321  _result.reset(new ResultType());
322  DAAL_CHECK(_result, services::ErrorNullResult);
323  _res = NULL;
324  return services::Status();
325  }
326 
332  services::SharedPtr<Batch<algorithmFPType, method> > clone() const
333  {
334  return services::SharedPtr<Batch<algorithmFPType, method> >(cloneImpl());
335  }
336 
337  virtual services::Status checkComputeParams() DAAL_C11_OVERRIDE;
338 
339 protected:
340  virtual Batch<algorithmFPType, method> * cloneImpl() const DAAL_C11_OVERRIDE
341  {
342  return new Batch<algorithmFPType, method>(*this);
343  }
344 
345  services::Status allocateResult() DAAL_C11_OVERRIDE
346  {
347  ResultPtr res = getResult();
348  DAAL_CHECK(res, services::ErrorNullResult);
349  services::Status s = res->template allocate<algorithmFPType>(&input, &parameter(), method);
350  _res = _result.get();
351  return s;
352  }
353 
354  void initialize()
355  {
356  _ac = new __DAAL_ALGORITHM_CONTAINER(batch, BatchContainer, algorithmFPType, method)(&_env);
357  _in = &input;
358  _result.reset(new ResultType());
359  }
360 };
362 } // namespace interface2
363 using interface2::BatchContainer;
364 using interface2::Batch;
365 
366 } // namespace daal::algorithms::gbt::classification::training
367 }
368 }
369 }
370 } // namespace daal
371 #endif // __GBT_CLASSIFICATION_TRAINING_BATCH_H__
daal::algorithms::gbt::classification::training::interface1::Batch
Trains a model of the Gradient Boosted Trees algorithm in the batch processing mode.
Definition: gbt_classification_training_batch.h:93
daal::algorithms::gbt::classification::training::interface1::Batch::parameter
DAAL_DEPRECATED ParameterType & parameter()
Definition: gbt_classification_training_batch.h:128
daal::algorithms::gbt::classification::training::interface2::BatchContainer::compute
services::Status compute() DAAL_C11_OVERRIDE
daal::algorithms::classifier::training::interface1::Result
Provides methods to access final results obtained with the compute() method in the batch processing m...
Definition: classifier_training_types.h:198
daal::algorithms::gbt::classification::training::interface2::Batch::getMethod
virtual int getMethod() const DAAL_C11_OVERRIDE
Definition: gbt_classification_training_batch.h:305
daal::batch
Definition: daal_defines.h:112
daal::algorithms::gbt::classification::training::interface1::BatchContainer::BatchContainer
DAAL_DEPRECATED BatchContainer(daal::services::Environment::env *daalEnv)
daal::algorithms::gbt::classification::training::interface2::BatchContainer::BatchContainer
BatchContainer(daal::services::Environment::env *daalEnv)
daal::algorithms::gbt::classification::training::interface2::Batch::clone
services::SharedPtr< Batch< algorithmFPType, method > > clone() const
Definition: gbt_classification_training_batch.h:332
daal::algorithms::classifier::training::interface1::Batch
Algorithm class for training the classifier model.
Definition: classifier_training_batch.h:58
daal::algorithms::gbt::classification::training::interface1::Batch::parameter
DAAL_DEPRECATED const ParameterType & parameter() const
Definition: gbt_classification_training_batch.h:134
daal::algorithms::gbt::classification::training::interface1::BatchContainer::~BatchContainer
DAAL_DEPRECATED ~BatchContainer()
daal::algorithms::gbt::classification::training::interface2::Batch::getResult
ResultPtr getResult()
Definition: gbt_classification_training_batch.h:311
daal::algorithms::gbt::classification::training::interface1::Batch::getInput
DAAL_DEPRECATED InputType * getInput() DAAL_C11_OVERRIDE
Definition: gbt_classification_training_batch.h:140
daal::algorithms::gbt::classification::training::interface1::Batch::getMethod
virtual DAAL_DEPRECATED_VIRTUAL int getMethod() const DAAL_C11_OVERRIDE
Definition: gbt_classification_training_batch.h:146
daal::algorithms::gbt::classification::training::interface1::BatchContainer::compute
DAAL_DEPRECATED services::Status compute() DAAL_C11_OVERRIDE
daal::services::ErrorNullResult
Definition: error_indexes.h:98
daal::algorithms::gbt::classification::training::interface2::Batch
Trains a model of the Gradient Boosted Trees algorithm in the batch processing mode.
Definition: gbt_classification_training_batch.h:252
daal::algorithms::gbt::classification::training::interface1::Batch::~Batch
DAAL_DEPRECATED ~Batch()
Definition: gbt_classification_training_batch.h:119
daal::algorithms::gbt::classification::training::interface1::Batch::clone
DAAL_DEPRECATED services::SharedPtr< Batch< algorithmFPType, method > > clone() const
Definition: gbt_classification_training_batch.h:173
daal::algorithms::gbt::classification::training::interface1::Batch::getResult
DAAL_DEPRECATED interface1::ResultPtr getResult()
Definition: gbt_classification_training_batch.h:152
daal::algorithms::gbt::classification::training::interface2::BatchContainer::~BatchContainer
~BatchContainer()
daal::algorithms::gbt::classification::training::interface2::Batch::~Batch
~Batch()
Definition: gbt_classification_training_batch.h:278
daal::algorithms::classifier::interface1::Parameter
Base class for the parameters of the classification algorithm.
Definition: classifier_model.h:69
daal::algorithms::gbt::classification::training::interface2::Batch::parameter
const ParameterType & parameter() const
Definition: gbt_classification_training_batch.h:293
daal::algorithms::classifier::training::interface1::Input
Base class for the input objects in the training stage of the classification algorithms.
Definition: classifier_training_types.h:110
daal::algorithms::gbt::classification::training::interface2::BatchContainer
Provides methods to run implementations of Gradient Boosted Trees model-based training. This class is associated with the daal::algorithms::gbt::classification::training::Batch class.
Definition: gbt_classification_training_batch.h:216
daal::algorithms::gbt::classification::training::interface1::Parameter
Gradient Boosted Trees algorithm parameters.
Definition: gbt_classification_training_types.h:95
daal::algorithms::gbt::classification::training::interface2::Batch::getInput
InputType * getInput() DAAL_C11_OVERRIDE
Definition: gbt_classification_training_batch.h:299
daal::algorithms::gbt::classification::training::interface1::BatchContainer
Provides methods to run implementations of Gradient Boosted Trees model-based training.
Definition: gbt_classification_training_batch.h:56
daal::algorithms::gbt::classification::training::interface2::Batch::parameter
ParameterType & parameter()
Definition: gbt_classification_training_batch.h:287
daal::algorithms::gbt::classification::training::interface1::Batch::input
InputType input
Definition: gbt_classification_training_batch.h:102
daal::algorithms::gbt::classification::training::interface1::Batch::resetResult
DAAL_DEPRECATED services::Status resetResult() DAAL_C11_OVERRIDE
Definition: gbt_classification_training_batch.h:160
daal::algorithms::gbt::classification::training::interface2::Batch::input
InputType input
Definition: gbt_classification_training_batch.h:261
daal::algorithms::TrainingContainerIface
Abstract interface class that provides virtual methods to access and run implementations of the model...
Definition: training.h:52
daal::algorithms::gbt::classification::training::interface2::Batch::resetResult
services::Status resetResult() DAAL_C11_OVERRIDE
Definition: gbt_classification_training_batch.h:319

For more complete information about compiler optimizations, see our Optimization Notice.