24 #ifndef __SGD_TYPES_H__
25 #define __SGD_TYPES_H__
27 #include "algorithms/algorithm.h"
28 #include "data_management/data/numeric_table.h"
29 #include "data_management/data/homogen_numeric_table.h"
30 #include "services/daal_defines.h"
31 #include "algorithms/optimization_solver/iterative_solver/iterative_solver_types.h"
32 #include "algorithms/engines/mt19937/mt19937.h"
38 namespace optimization_solver
69 pastUpdateVector = iterative_solver::lastOptionalData + 1,
70 pastWorkValue = pastUpdateVector + 1 ,
71 lastOptionalData = pastWorkValue
87 struct DAAL_EXPORT BaseParameter :
public optimization_solver::iterative_solver::interface1::Parameter
101 const sum_of_functions::interface1::BatchPtr &
function,
102 size_t nIterations = 100,
103 double accuracyThreshold = 1.0e-05,
104 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
105 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
106 new data_management::HomogenNumericTable<double>(
107 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
108 size_t batchSize = 1,
111 virtual ~BaseParameter() {}
118 virtual services::Status check()
const;
120 data_management::NumericTablePtr batchIndices;
123 data_management::NumericTablePtr learningRateSequence;
126 engines::EnginePtr engine;
131 template<Method method>
132 struct Parameter :
public BaseParameter {};
142 struct DAAL_EXPORT Parameter<defaultDense> :
public BaseParameter
153 DAAL_DEPRECATED Parameter(
154 const sum_of_functions::interface1::BatchPtr &
function,
155 size_t nIterations = 100,
156 double accuracyThreshold = 1.0e-05,
157 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
158 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
159 new data_management::HomogenNumericTable<double>(
160 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
168 DAAL_DEPRECATED_VIRTUAL
virtual services::Status check()
const;
170 DAAL_DEPRECATED_VIRTUAL
virtual ~Parameter() {}
182 struct DAAL_EXPORT Parameter<miniBatch> :
public BaseParameter
199 DAAL_DEPRECATED Parameter(
200 const sum_of_functions::interface1::BatchPtr &
function,
201 size_t nIterations = 100,
202 double accuracyThreshold = 1.0e-05,
203 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
204 size_t batchSize = 128,
205 data_management::NumericTablePtr conservativeSequence = data_management::NumericTablePtr(
206 new data_management::HomogenNumericTable<double>(
207 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
208 size_t innerNIterations = 5,
209 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
210 new data_management::HomogenNumericTable<double>(
211 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
219 DAAL_DEPRECATED_VIRTUAL
virtual services::Status check()
const;
221 DAAL_DEPRECATED_VIRTUAL
virtual ~Parameter() {}
223 data_management::NumericTablePtr conservativeSequence;
224 size_t innerNIterations;
237 struct DAAL_EXPORT Parameter<momentum> :
public BaseParameter
253 DAAL_DEPRECATED Parameter(
254 const sum_of_functions::interface1::BatchPtr&
function,
255 double momentum = 0.9,
256 size_t nIterations = 100,
257 double accuracyThreshold = 1.0e-05,
258 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
259 size_t batchSize = 128,
260 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
261 new data_management::HomogenNumericTable<double>(
262 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
270 DAAL_DEPRECATED_VIRTUAL
virtual services::Status check()
const;
272 DAAL_DEPRECATED_VIRTUAL
virtual ~Parameter() {}
286 class DAAL_EXPORT Input :
public optimization_solver::iterative_solver::interface1::Input
289 typedef optimization_solver::iterative_solver::interface1::Input super;
290 DAAL_DEPRECATED Input();
291 DAAL_DEPRECATED Input(
const Input& other);
300 DAAL_DEPRECATED data_management::NumericTablePtr
get(OptionalDataId id)
const;
307 DAAL_DEPRECATED
void set(OptionalDataId
id,
const data_management::NumericTablePtr &ptr);
316 DAAL_DEPRECATED
virtual services::Status check(
const daal::algorithms::Parameter *par,
int method)
const DAAL_C11_OVERRIDE;
326 class DAAL_EXPORT Result :
public optimization_solver::iterative_solver::interface1::Result
329 DECLARE_SERIALIZABLE_CAST(Result);
330 typedef optimization_solver::iterative_solver::interface1::Result super;
332 DAAL_DEPRECATED Result() {}
344 template <
typename algorithmFPType>
345 DAAL_EXPORT DAAL_DEPRECATED services::Status allocate(
const daal::algorithms::Input *input,
const daal::algorithms::Parameter *par,
const int method);
352 DAAL_DEPRECATED data_management::NumericTablePtr
get(OptionalDataId id)
const;
359 DAAL_DEPRECATED
void set(OptionalDataId
id,
const data_management::NumericTablePtr &ptr);
369 DAAL_DEPRECATED_VIRTUAL
virtual services::Status check(
const daal::algorithms::Input *input,
const daal::algorithms::Parameter *par,
370 int method)
const DAAL_C11_OVERRIDE;
373 typedef services::SharedPtr<Result> ResultPtr;
393 struct DAAL_EXPORT BaseParameter :
public optimization_solver::iterative_solver::Parameter
407 const sum_of_functions::BatchPtr &
function,
408 size_t nIterations = 100,
409 double accuracyThreshold = 1.0e-05,
410 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
411 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
412 new data_management::HomogenNumericTable<double>(
413 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
414 size_t batchSize = 1,
417 virtual ~BaseParameter() {}
424 virtual services::Status check()
const;
426 data_management::NumericTablePtr batchIndices;
429 data_management::NumericTablePtr learningRateSequence;
432 engines::EnginePtr engine;
437 template<Method method>
438 struct Parameter :
public BaseParameter {};
448 struct DAAL_EXPORT Parameter<defaultDense> :
public BaseParameter
460 const sum_of_functions::BatchPtr &
function,
461 size_t nIterations = 100,
462 double accuracyThreshold = 1.0e-05,
463 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
464 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
465 new data_management::HomogenNumericTable<double>(
466 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
474 virtual services::Status check()
const;
476 virtual ~Parameter() {}
488 struct DAAL_EXPORT Parameter<miniBatch> :
public BaseParameter
506 const sum_of_functions::BatchPtr &
function,
507 size_t nIterations = 100,
508 double accuracyThreshold = 1.0e-05,
509 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
510 size_t batchSize = 128,
511 data_management::NumericTablePtr conservativeSequence = data_management::NumericTablePtr(
512 new data_management::HomogenNumericTable<double>(
513 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
514 size_t innerNIterations = 5,
515 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
516 new data_management::HomogenNumericTable<double>(
517 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
525 virtual services::Status check()
const;
527 virtual ~Parameter() {}
529 data_management::NumericTablePtr conservativeSequence;
530 size_t innerNIterations;
543 struct DAAL_EXPORT Parameter<momentum> :
public BaseParameter
560 const sum_of_functions::BatchPtr&
function,
561 double momentum = 0.9,
562 size_t nIterations = 100,
563 double accuracyThreshold = 1.0e-05,
564 data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
565 size_t batchSize = 128,
566 data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
567 new data_management::HomogenNumericTable<double>(
568 1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
576 virtual services::Status check()
const;
578 virtual ~Parameter() {}
592 class DAAL_EXPORT Input :
public optimization_solver::iterative_solver::Input
595 typedef optimization_solver::iterative_solver::Input super;
597 Input(
const Input& other);
606 data_management::NumericTablePtr
get(OptionalDataId id)
const;
613 void set(OptionalDataId
id,
const data_management::NumericTablePtr &ptr);
622 virtual services::Status check(
const daal::algorithms::Parameter *par,
int method)
const DAAL_C11_OVERRIDE;
632 class DAAL_EXPORT Result :
public optimization_solver::iterative_solver::Result
635 DECLARE_SERIALIZABLE_CAST(Result);
636 typedef optimization_solver::iterative_solver::Result super;
650 template <
typename algorithmFPType>
651 DAAL_EXPORT services::Status allocate(
const daal::algorithms::Input *input,
const daal::algorithms::Parameter *par,
const int method);
658 data_management::NumericTablePtr
get(OptionalDataId id)
const;
665 void set(OptionalDataId
id,
const data_management::NumericTablePtr &ptr);
675 virtual services::Status check(
const daal::algorithms::Input *input,
const daal::algorithms::Parameter *par,
676 int method)
const DAAL_C11_OVERRIDE;
679 typedef services::SharedPtr<Result> ResultPtr;
684 using interface2::BaseParameter;
685 using interface2::Parameter;
686 using interface2::Input;
687 using interface2::Result;
688 using interface2::ResultPtr;
daal::algorithms::optimization_solver::sgd::interface2::BaseParameter::engine
engines::EnginePtr engine
Definition: sgd_types.h:432
daal::algorithms::optimization_solver::sgd::momentum
Definition: sgd_types.h:60
daal::algorithms::optimization_solver::sgd::interface2::Result
Results obtained with the compute() method of the sgd algorithm in the batch processing mode...
Definition: sgd_types.h:632
daal::algorithms::optimization_solver::sgd::Method
Method
Definition: sgd_types.h:56
daal::algorithms::optimization_solver::sgd::pastUpdateVector
Definition: sgd_types.h:69
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::learningRateSequence
data_management::NumericTablePtr learningRateSequence
Definition: sgd_types.h:123
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::engine
engines::EnginePtr engine
Definition: sgd_types.h:126
daal::algorithms::optimization_solver::sgd::interface2::BaseParameter::learningRateSequence
data_management::NumericTablePtr learningRateSequence
Definition: sgd_types.h:429
daal::algorithms::optimization_solver::sgd::interface1::Parameter< momentum >::momentum
double momentum
Definition: sgd_types.h:274
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::seed
size_t seed
Definition: sgd_types.h:124
daal::algorithms::optimization_solver::sgd::interface2::Parameter
Definition: sgd_types.h:438
daal::algorithms::optimization_solver::sgd::interface2::BaseParameter::batchIndices
data_management::NumericTablePtr batchIndices
Definition: sgd_types.h:426
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter
BaseParameter base class for the Stochastic gradient descent algorithm
Definition: sgd_types.h:87
daal::algorithms::optimization_solver::sgd::interface2::BaseParameter::seed
size_t seed
Definition: sgd_types.h:430
daal::algorithms::optimization_solver::sgd::pastWorkValue
Definition: sgd_types.h:70
daal::algorithms::optimization_solver::sgd::defaultDense
Definition: sgd_types.h:58
daal::algorithms::optimization_solver::sgd::miniBatch
Definition: sgd_types.h:59
daal::algorithms::optimization_solver::sgd::OptionalDataId
OptionalDataId
Definition: sgd_types.h:67
daal::algorithms::optimization_solver::sgd::interface2::Parameter< miniBatch >::conservativeSequence
data_management::NumericTablePtr conservativeSequence
Definition: sgd_types.h:529
daal::algorithms::optimization_solver::sgd::interface2::Parameter< momentum >::momentum
double momentum
Definition: sgd_types.h:580
daal::algorithms::optimization_solver::sgd::interface2::Input
Input for the Stochastic gradient descent algorithm
Definition: sgd_types.h:592
daal::algorithms::optimization_solver::sgd::interface1::Parameter< miniBatch >::conservativeSequence
data_management::NumericTablePtr conservativeSequence
Definition: sgd_types.h:223
daal::algorithms::optimization_solver::sgd::interface1::Input
Input for the Stochastic gradient descent algorithm
Definition: sgd_types.h:286
daal::algorithms::optimization_solver::iterative_solver::interface1::Input
Input parameters for the iterative solver algorithm
Definition: iterative_solver_types.h:160
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::batchIndices
data_management::NumericTablePtr batchIndices
Definition: sgd_types.h:120
daal::algorithms::optimization_solver::sgd::interface1::Result
Results obtained with the compute() method of the sgd algorithm in the batch processing mode...
Definition: sgd_types.h:326
daal::algorithms::optimization_solver::sgd::interface2::BaseParameter
BaseParameter base class for the Stochastic gradient descent algorithm
Definition: sgd_types.h:393
daal::algorithms::optimization_solver::iterative_solver::interface1::Result
Results obtained with the compute() method of the iterative solver algorithm in the batch processing ...
Definition: iterative_solver_types.h:223
daal::algorithms::optimization_solver::sgd::interface1::Parameter
Definition: sgd_types.h:132
daal::algorithms::optimization_solver::iterative_solver::interface1::Parameter
Parameter base class for the iterative solver algorithm
Definition: iterative_solver_types.h:115
daal::algorithms::em_gmm::nIterations
Definition: em_gmm_types.h:99