C++ API Reference for Intel® Data Analytics Acceleration Library 2020 Update 1

sgd_types.h
1 /* file: sgd_types.h */
2 /*******************************************************************************
3 * Copyright 2014-2020 Intel Corporation
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17 
18 /*
19 //++
20 // Implementation of the Stochastic gradient descent algorithm types.
21 //--
22 */
23 
24 #ifndef __SGD_TYPES_H__
25 #define __SGD_TYPES_H__
26 
27 #include "algorithms/algorithm.h"
28 #include "data_management/data/numeric_table.h"
29 #include "data_management/data/homogen_numeric_table.h"
30 #include "services/daal_defines.h"
31 #include "algorithms/optimization_solver/iterative_solver/iterative_solver_types.h"
32 #include "algorithms/engines/mt19937/mt19937.h"
33 
34 namespace daal
35 {
36 namespace algorithms
37 {
38 namespace optimization_solver
39 {
49 namespace sgd
50 {
51 
// Computation methods of the SGD solver. Each enumerator names one of the
// explicit Parameter<> specializations (Parameter<defaultDense>,
// Parameter<miniBatch>, Parameter<momentum>) declared in interface1 and
// interface2. The values are spelled out explicitly because they are part of
// the public interface; do not renumber them.
56 enum Method
57 {
58  defaultDense = 0,
59  miniBatch = 1,
60  momentum = 2
61 };
62 
// Identifiers of the SGD-specific optional data. Numbering starts right after
// iterative_solver::lastOptionalData, so these ids extend the base iterative
// solver's optional-data ids instead of reusing them. They are the keys taken
// by the get(OptionalDataId) / set(OptionalDataId, ...) overloads of
// sgd::Input and sgd::Result.
67 enum OptionalDataId
68 {
69  pastUpdateVector = iterative_solver::lastOptionalData + 1,  // presumably the update vector carried over from a previous run - confirm in DAAL docs
70  pastWorkValue = pastUpdateVector + 1 ,                     // presumably the work-point value carried over from a previous run - confirm in DAAL docs
71  lastOptionalData = pastWorkValue                           // last SGD optional-data id; plays the role iterative_solver::lastOptionalData plays above
72 };
73 
// Legacy (interface1) SGD parameter, input and result types, built on
// iterative_solver::interface1 and sum_of_functions::interface1. All
// method-specific Parameter constructors and the Input/Result members here are
// marked DAAL_DEPRECATED.
77 namespace interface1
78 {
79 
// Parameters shared by all SGD methods (interface1 version).
86 /* [interface1::BaseParameter source code] */
87 struct DAAL_EXPORT BaseParameter : public optimization_solver::iterative_solver::interface1::Parameter
88 {
// Constructor arguments and their defaults:
//   function             - sum-of-functions objective the solver works on (required)
//   nIterations          - default 100
//   accuracyThreshold    - default 1.0e-05
//   batchIndices         - default: empty NumericTablePtr
//   learningRateSequence - default: a newly allocated 1x1 HomogenNumericTable<double>
//                          constructed with doAllocate and the value 1.0. A default
//                          argument is evaluated on every call that omits it, so each
//                          object gets its own table rather than a shared one.
//   batchSize            - default 1; there is no batchSize member in this struct, so it is
//                          presumably stored by the iterative_solver base - TODO confirm
//   seed                 - default 777
100  BaseParameter(
101  const sum_of_functions::interface1::BatchPtr &function,
102  size_t nIterations = 100,
103  double accuracyThreshold = 1.0e-05,
104  data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
105  data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
106  new data_management::HomogenNumericTable<double>(
107  1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
108  size_t batchSize = 1,
109  size_t seed = 777 );
110 
111  virtual ~BaseParameter() {}
112 
// Presumably validates the parameter values; the outcome is reported as services::Status.
118  virtual services::Status check() const;
119 
120  data_management::NumericTablePtr batchIndices;          // constructor default: empty pointer
123  data_management::NumericTablePtr learningRateSequence;  // constructor default: 1x1 table holding 1.0
124  size_t seed;                                            // constructor default: 777; presumably seeds `engine` - confirm
126  engines::EnginePtr engine;                              // random-number engine; mt19937.h is included, presumably the default - confirm
128 };
129 /* [interface1::BaseParameter source code] */
130 
// Primary template: adds nothing to BaseParameter. Only the explicit
// specializations below are usable. BaseParameter has no default constructor,
// so this primary template cannot be default-constructed.
131 template<Method method>
132 struct Parameter : public BaseParameter {};
133 
// Parameters of the defaultDense method. Unlike BaseParameter, the constructor
// takes no batchSize argument.
140 /* [interface1::ParameterDefaultDense source code] */
141 template<>
142 struct DAAL_EXPORT Parameter<defaultDense> : public BaseParameter
143 {
// Defaults: nIterations 100, accuracyThreshold 1.0e-05, empty batchIndices,
// learningRateSequence = new 1x1 table holding 1.0, seed 777.
153  DAAL_DEPRECATED Parameter(
154  const sum_of_functions::interface1::BatchPtr &function,
155  size_t nIterations = 100,
156  double accuracyThreshold = 1.0e-05,
157  data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
158  data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
159  new data_management::HomogenNumericTable<double>(
160  1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
161  size_t seed = 777 );
162 
// Presumably validates the defaultDense parameters; returns services::Status.
168  DAAL_DEPRECATED_VIRTUAL virtual services::Status check() const;
169 
170  DAAL_DEPRECATED_VIRTUAL virtual ~Parameter() {}
171 };
172 /* [interface1::ParameterDefaultDense source code] */
173 
// Parameters of the miniBatch method. The argument order differs from
// BaseParameter: batchSize (default 128) follows batchIndices, and
// learningRateSequence comes after conservativeSequence and innerNIterations.
180 /* [interface1::ParameterMiniBatch source code] */
181 template<>
182 struct DAAL_EXPORT Parameter<miniBatch> : public BaseParameter
183 {
// Defaults: nIterations 100, accuracyThreshold 1.0e-05, empty batchIndices,
// batchSize 128, conservativeSequence = new 1x1 table holding 1.0,
// innerNIterations 5, learningRateSequence = new 1x1 table holding 1.0, seed 777.
199  DAAL_DEPRECATED Parameter(
200  const sum_of_functions::interface1::BatchPtr &function,
201  size_t nIterations = 100,
202  double accuracyThreshold = 1.0e-05,
203  data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
204  size_t batchSize = 128,
205  data_management::NumericTablePtr conservativeSequence = data_management::NumericTablePtr(
206  new data_management::HomogenNumericTable<double>(
207  1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
208  size_t innerNIterations = 5,
209  data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
210  new data_management::HomogenNumericTable<double>(
211  1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
212  size_t seed = 777 );
213 
// Presumably validates the miniBatch parameters; returns services::Status.
219  DAAL_DEPRECATED_VIRTUAL virtual services::Status check() const;
220 
221  DAAL_DEPRECATED_VIRTUAL virtual ~Parameter() {}
222 
223  data_management::NumericTablePtr conservativeSequence;  // constructor default: 1x1 table holding 1.0; meaning per DAAL SGD docs
224  size_t innerNIterations;                                // constructor default: 5; presumably the inner-loop iteration count
225 };
226 /* [interface1::ParameterMiniBatch source code] */
// Parameters of the momentum method. `momentum` (default 0.9) is the second
// constructor argument, ahead of nIterations, and batchSize defaults to 128.
// Inside this specialization the data member `momentum` hides the
// sgd::momentum enumerator of the same name.
235 /* [interface1::ParameterMomentum source code] */
236 template<>
237 struct DAAL_EXPORT Parameter<momentum> : public BaseParameter
238 {
253  DAAL_DEPRECATED Parameter(
254  const sum_of_functions::interface1::BatchPtr& function,
255  double momentum = 0.9,
256  size_t nIterations = 100,
257  double accuracyThreshold = 1.0e-05,
258  data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
259  size_t batchSize = 128,
260  data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
261  new data_management::HomogenNumericTable<double>(
262  1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
263  size_t seed = 777 );
264 
// Presumably validates the momentum parameters; returns services::Status.
270  DAAL_DEPRECATED_VIRTUAL virtual services::Status check() const;
271 
272  DAAL_DEPRECATED_VIRTUAL virtual ~Parameter() {}
273 
274  double momentum;  // constructor default: 0.9
275 };
276 /* [interface1::ParameterMomentum source code] */
// Input of the SGD solver (interface1). Adds get/set overloads keyed by
// sgd::OptionalDataId. The `using super::set/get` declarations keep the
// base-class overloads visible; without them the new overloads would hide them.
285 /* [interface1::Input source code] */
286 class DAAL_EXPORT Input : public optimization_solver::iterative_solver::interface1::Input
287 {
288 public:
289  typedef optimization_solver::iterative_solver::interface1::Input super;
290  DAAL_DEPRECATED Input();
291  DAAL_DEPRECATED Input(const Input& other);
292  using super::set;
293  using super::get;
294 
// Returns the optional input table registered under `id`.
300  DAAL_DEPRECATED data_management::NumericTablePtr get(OptionalDataId id) const;
301 
// Stores `ptr` as the optional input table for `id`.
307  DAAL_DEPRECATED void set(OptionalDataId id, const data_management::NumericTablePtr &ptr);
308 
// Overrides the base check; presumably validates the input against `par` and `method`.
316  DAAL_DEPRECATED virtual services::Status check(const daal::algorithms::Parameter *par, int method) const DAAL_C11_OVERRIDE;
317 };
318 /* [interface1::Input source code] */
// Result of the SGD solver (interface1). Uses the same get/set overload
// pattern as Input.
325 /* [interface1::Result source code] */
326 class DAAL_EXPORT Result : public optimization_solver::iterative_solver::interface1::Result
327 {
328 public:
329  DECLARE_SERIALIZABLE_CAST(Result);
330  typedef optimization_solver::iterative_solver::interface1::Result super;
331 
332  DAAL_DEPRECATED Result() {}
333  using super::set;
334  using super::get;
335 
// Presumably allocates result storage for the given input, parameter and
// method, with algorithmFPType as the floating-point type - confirm in the implementation.
344  template <typename algorithmFPType>
345  DAAL_EXPORT DAAL_DEPRECATED services::Status allocate(const daal::algorithms::Input *input, const daal::algorithms::Parameter *par, const int method);
346 
// Returns the optional result table registered under `id`.
352  DAAL_DEPRECATED data_management::NumericTablePtr get(OptionalDataId id) const;
353 
// Stores `ptr` as the optional result table for `id`.
359  DAAL_DEPRECATED void set(OptionalDataId id, const data_management::NumericTablePtr &ptr);
360 
// Overrides the base check; presumably validates the result against `input`, `par` and `method`.
369  DAAL_DEPRECATED_VIRTUAL virtual services::Status check(const daal::algorithms::Input *input, const daal::algorithms::Parameter *par,
370  int method) const DAAL_C11_OVERRIDE;
371 
372 };
// Shared-pointer alias for the interface1 Result.
373 typedef services::SharedPtr<Result> ResultPtr;
374 /* [interface1::Result source code] */
377 } // namespace interface1
378 
379 
// Current (interface2) SGD parameter, input and result types, built on
// iterative_solver and sum_of_functions (non-interface1). The
// using-declarations that follow this namespace make these the types named
// sgd::BaseParameter, sgd::Parameter, sgd::Input, sgd::Result and sgd::ResultPtr.
383 namespace interface2
384 {
385 
// Parameters shared by all SGD methods.
392 /* [BaseParameter source code] */
393 struct DAAL_EXPORT BaseParameter : public optimization_solver::iterative_solver::Parameter
394 {
// Constructor arguments and their defaults:
//   function             - sum-of-functions objective the solver works on (required)
//   nIterations          - default 100
//   accuracyThreshold    - default 1.0e-05
//   batchIndices         - default: empty NumericTablePtr
//   learningRateSequence - default: a newly allocated 1x1 HomogenNumericTable<double>
//                          constructed with doAllocate and the value 1.0. A default
//                          argument is evaluated on every call that omits it, so each
//                          object gets its own table rather than a shared one.
//   batchSize            - default 1; there is no batchSize member in this struct, so it is
//                          presumably stored by the iterative_solver base - TODO confirm
//   seed                 - default 777
406  BaseParameter(
407  const sum_of_functions::BatchPtr &function,
408  size_t nIterations = 100,
409  double accuracyThreshold = 1.0e-05,
410  data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
411  data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
412  new data_management::HomogenNumericTable<double>(
413  1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
414  size_t batchSize = 1,
415  size_t seed = 777 );
416 
417  virtual ~BaseParameter() {}
418 
// Presumably validates the parameter values; the outcome is reported as services::Status.
424  virtual services::Status check() const;
425 
426  data_management::NumericTablePtr batchIndices;          // constructor default: empty pointer
429  data_management::NumericTablePtr learningRateSequence;  // constructor default: 1x1 table holding 1.0
430  size_t seed;                                            // constructor default: 777; presumably seeds `engine` - confirm
432  engines::EnginePtr engine;                              // random-number engine; mt19937.h is included, presumably the default - confirm
434 };
435 /* [BaseParameter source code] */
436 
// Primary template: adds nothing to BaseParameter. Only the explicit
// specializations below are usable. BaseParameter has no default constructor,
// so this primary template cannot be default-constructed.
437 template<Method method>
438 struct Parameter : public BaseParameter {};
439 
// Parameters of the defaultDense method. Unlike BaseParameter, the constructor
// takes no batchSize argument.
446 /* [ParameterDefaultDense source code] */
447 template<>
448 struct DAAL_EXPORT Parameter<defaultDense> : public BaseParameter
449 {
// Defaults: nIterations 100, accuracyThreshold 1.0e-05, empty batchIndices,
// learningRateSequence = new 1x1 table holding 1.0, seed 777.
459  Parameter(
460  const sum_of_functions::BatchPtr &function,
461  size_t nIterations = 100,
462  double accuracyThreshold = 1.0e-05,
463  data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
464  data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
465  new data_management::HomogenNumericTable<double>(
466  1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
467  size_t seed = 777 );
468 
// Presumably validates the defaultDense parameters; returns services::Status.
474  virtual services::Status check() const;
475 
476  virtual ~Parameter() {}
477 };
478 /* [ParameterDefaultDense source code] */
479 
// Parameters of the miniBatch method. The argument order differs from
// BaseParameter: batchSize (default 128) follows batchIndices, and
// learningRateSequence comes after conservativeSequence and innerNIterations.
486 /* [ParameterMiniBatch source code] */
487 template<>
488 struct DAAL_EXPORT Parameter<miniBatch> : public BaseParameter
489 {
// Defaults: nIterations 100, accuracyThreshold 1.0e-05, empty batchIndices,
// batchSize 128, conservativeSequence = new 1x1 table holding 1.0,
// innerNIterations 5, learningRateSequence = new 1x1 table holding 1.0, seed 777.
505  Parameter(
506  const sum_of_functions::BatchPtr &function,
507  size_t nIterations = 100,
508  double accuracyThreshold = 1.0e-05,
509  data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
510  size_t batchSize = 128,
511  data_management::NumericTablePtr conservativeSequence = data_management::NumericTablePtr(
512  new data_management::HomogenNumericTable<double>(
513  1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
514  size_t innerNIterations = 5,
515  data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
516  new data_management::HomogenNumericTable<double>(
517  1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
518  size_t seed = 777 );
519 
// Presumably validates the miniBatch parameters; returns services::Status.
525  virtual services::Status check() const;
526 
527  virtual ~Parameter() {}
528 
529  data_management::NumericTablePtr conservativeSequence;  // constructor default: 1x1 table holding 1.0; meaning per DAAL SGD docs
530  size_t innerNIterations;                                // constructor default: 5; presumably the inner-loop iteration count
531 };
532 /* [ParameterMiniBatch source code] */
// Parameters of the momentum method. `momentum` (default 0.9) is the second
// constructor argument, ahead of nIterations, and batchSize defaults to 128.
// Inside this specialization the data member `momentum` hides the
// sgd::momentum enumerator of the same name.
541 /* [ParameterMomentum source code] */
542 template<>
543 struct DAAL_EXPORT Parameter<momentum> : public BaseParameter
544 {
559  Parameter(
560  const sum_of_functions::BatchPtr& function,
561  double momentum = 0.9,
562  size_t nIterations = 100,
563  double accuracyThreshold = 1.0e-05,
564  data_management::NumericTablePtr batchIndices = data_management::NumericTablePtr(),
565  size_t batchSize = 128,
566  data_management::NumericTablePtr learningRateSequence = data_management::NumericTablePtr(
567  new data_management::HomogenNumericTable<double>(
568  1, 1, data_management::NumericTableIface::doAllocate, 1.0)),
569  size_t seed = 777 );
570 
// Presumably validates the momentum parameters; returns services::Status.
576  virtual services::Status check() const;
577 
578  virtual ~Parameter() {}
579 
580  double momentum;  // constructor default: 0.9
581 };
582 /* [ParameterMomentum source code] */
// Input of the SGD solver. Adds get/set overloads keyed by sgd::OptionalDataId.
// The `using super::set/get` declarations keep the base-class overloads
// visible; without them the new overloads would hide them.
591 /* [Input source code] */
592 class DAAL_EXPORT Input : public optimization_solver::iterative_solver::Input
593 {
594 public:
595  typedef optimization_solver::iterative_solver::Input super;
596  Input();
597  Input(const Input& other);
598  using super::set;
599  using super::get;
600 
// Returns the optional input table registered under `id`.
606  data_management::NumericTablePtr get(OptionalDataId id) const;
607 
// Stores `ptr` as the optional input table for `id`.
613  void set(OptionalDataId id, const data_management::NumericTablePtr &ptr);
614 
// Overrides the base check; presumably validates the input against `par` and `method`.
622  virtual services::Status check(const daal::algorithms::Parameter *par, int method) const DAAL_C11_OVERRIDE;
623 };
624 /* [Input source code] */
// Result of the SGD solver. Uses the same get/set overload pattern as Input.
631 /* [Result source code] */
632 class DAAL_EXPORT Result : public optimization_solver::iterative_solver::Result
633 {
634 public:
635  DECLARE_SERIALIZABLE_CAST(Result);
636  typedef optimization_solver::iterative_solver::Result super;
637 
638  Result() {}
639  using super::set;
640  using super::get;
641 
// Presumably allocates result storage for the given input, parameter and
// method, with algorithmFPType as the floating-point type - confirm in the implementation.
650  template <typename algorithmFPType>
651  DAAL_EXPORT services::Status allocate(const daal::algorithms::Input *input, const daal::algorithms::Parameter *par, const int method);
652 
// Returns the optional result table registered under `id`.
658  data_management::NumericTablePtr get(OptionalDataId id) const;
659 
// Stores `ptr` as the optional result table for `id`.
665  void set(OptionalDataId id, const data_management::NumericTablePtr &ptr);
666 
// Overrides the base check; presumably validates the result against `input`, `par` and `method`.
675  virtual services::Status check(const daal::algorithms::Input *input, const daal::algorithms::Parameter *par,
676  int method) const DAAL_C11_OVERRIDE;
677 
678 };
// Shared-pointer alias for Result.
679 typedef services::SharedPtr<Result> ResultPtr;
680 /* [Result source code] */
683 } // namespace interface2
// Make the interface2 types the defaults of namespace sgd, so that
// sgd::BaseParameter, sgd::Parameter, sgd::Input, sgd::Result and
// sgd::ResultPtr refer to the current (non-deprecated) interface.
684 using interface2::BaseParameter;
685 using interface2::Parameter;
686 using interface2::Input;
687 using interface2::Result;
688 using interface2::ResultPtr;
689 
690 } // namespace sgd
691 } // namespace optimization_solver
692 } // namespace algorithms
693 } // namespace daal
694 #endif // __SGD_TYPES_H__
daal::algorithms::optimization_solver::sgd::interface2::BaseParameter::engine
engines::EnginePtr engine
Definition: sgd_types.h:432
daal::algorithms::optimization_solver::sgd::momentum
Definition: sgd_types.h:60
daal::algorithms::optimization_solver::sgd::interface2::Result
Results obtained with the compute() method of the SGD algorithm in the batch processing mode.
Definition: sgd_types.h:632
daal::algorithms::optimization_solver::sgd::Method
Method
Definition: sgd_types.h:56
daal::algorithms::optimization_solver::sgd::pastUpdateVector
Definition: sgd_types.h:69
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::learningRateSequence
data_management::NumericTablePtr learningRateSequence
Definition: sgd_types.h:123
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::engine
engines::EnginePtr engine
Definition: sgd_types.h:126
daal::algorithms::optimization_solver::sgd::interface2::BaseParameter::learningRateSequence
data_management::NumericTablePtr learningRateSequence
Definition: sgd_types.h:429
daal::algorithms::optimization_solver::sgd::interface1::Parameter< momentum >::momentum
double momentum
Definition: sgd_types.h:274
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::seed
size_t seed
Definition: sgd_types.h:124
daal::algorithms::optimization_solver::sgd::interface2::Parameter
Definition: sgd_types.h:438
daal::algorithms::optimization_solver::sgd::interface2::BaseParameter::batchIndices
data_management::NumericTablePtr batchIndices
Definition: sgd_types.h:426
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter
BaseParameter: base class for the parameters of the stochastic gradient descent (SGD) algorithm
Definition: sgd_types.h:87
daal::algorithms::optimization_solver::sgd::interface2::BaseParameter::seed
size_t seed
Definition: sgd_types.h:430
daal::algorithms::optimization_solver::sgd::pastWorkValue
Definition: sgd_types.h:70
daal::algorithms::optimization_solver::sgd::defaultDense
Definition: sgd_types.h:58
daal::algorithms::optimization_solver::sgd::miniBatch
Definition: sgd_types.h:59
daal::algorithms::optimization_solver::sgd::OptionalDataId
OptionalDataId
Definition: sgd_types.h:67
daal_defines.h
daal::algorithms::optimization_solver::sgd::interface2::Parameter< miniBatch >::conservativeSequence
data_management::NumericTablePtr conservativeSequence
Definition: sgd_types.h:529
daal::algorithms::optimization_solver::sgd::interface2::Parameter< momentum >::momentum
double momentum
Definition: sgd_types.h:580
daal::algorithms::optimization_solver::sgd::interface2::Input
Input for the Stochastic gradient descent algorithm
Definition: sgd_types.h:592
daal::algorithms::optimization_solver::sgd::interface1::Parameter< miniBatch >::conservativeSequence
data_management::NumericTablePtr conservativeSequence
Definition: sgd_types.h:223
daal::algorithms::optimization_solver::sgd::interface1::Input
Input for the Stochastic gradient descent algorithm
Definition: sgd_types.h:286
daal::algorithms::optimization_solver::iterative_solver::interface1::Input
Input parameters for the iterative solver algorithm
Definition: iterative_solver_types.h:160
daal::algorithms::optimization_solver::sgd::interface1::BaseParameter::batchIndices
data_management::NumericTablePtr batchIndices
Definition: sgd_types.h:120
daal::algorithms::optimization_solver::sgd::interface1::Result
Results obtained with the compute() method of the SGD algorithm in the batch processing mode.
Definition: sgd_types.h:326
daal::algorithms::optimization_solver::sgd::interface2::BaseParameter
BaseParameter: base class for the parameters of the stochastic gradient descent (SGD) algorithm
Definition: sgd_types.h:393
daal::algorithms::optimization_solver::iterative_solver::interface1::Result
Results obtained with the compute() method of the iterative solver algorithm in the batch processing mode.
Definition: iterative_solver_types.h:223
daal::algorithms::optimization_solver::sgd::interface1::Parameter
Definition: sgd_types.h:132
daal::algorithms::optimization_solver::iterative_solver::interface1::Parameter
Parameter base class for the iterative solver algorithm
Definition: iterative_solver_types.h:115
daal::algorithms::em_gmm::nIterations
Definition: em_gmm_types.h:99

For more complete information about compiler optimizations, see our Optimization Notice.