24 #ifndef __NEURAL_NETWORK_TRAINING_MODEL_H__
25 #define __NEURAL_NETWORK_TRAINING_MODEL_H__
27 #include "services/daal_defines.h"
28 #include "data_management/data/tensor.h"
29 #include "data_management/data/numeric_table.h"
30 #include "services/daal_memory.h"
31 #include "algorithms/neural_networks/layers/layer.h"
32 #include "algorithms/neural_networks/layers/layer_types.h"
33 #include "algorithms/neural_networks/layers/loss/loss_layer_forward.h"
34 #include "algorithms/neural_networks/layers/split/split_layer_forward.h"
35 #include "algorithms/neural_networks/neural_networks_prediction_model.h"
36 #include "algorithms/neural_networks/neural_networks_training_topology.h"
38 #include "algorithms/optimization_solver/iterative_solver/iterative_solver_batch.h"
47 namespace neural_networks
/**
 * Parameters of neural network training.
 * Derives from daal::algorithms::Parameter and carries two public fields:
 * the optimization solver and the random engine.
 */
62 class Parameter :
public daal::algorithms::Parameter
/*
 * Constructor. Both arguments are optional:
 *   optimizationSolver_  defaults to an empty SharedPtr (no solver set);
 *   engine_              defaults to an mt19937 engine created for DAAL_ALGORITHM_FP_TYPE.
 * NOTE(review): only the initializer for optimizationSolver is visible in this listing;
 * the initializer for `engine` and the constructor body are not shown - confirm against the header.
 */
71 Parameter(
const services::SharedPtr<optimization_solver::iterative_solver::Batch > &optimizationSolver_ = services::SharedPtr<optimization_solver::iterative_solver::Batch>(),
72 engines::EnginePtr engine_ = engines::mt19937::Batch<DAAL_ALGORITHM_FP_TYPE>::create()) :
73 optimizationSolver(optimizationSolver_),
/* Iterative optimization solver (batch mode) held by shared pointer; may be empty. */
76 services::SharedPtr<optimization_solver::iterative_solver::Batch> optimizationSolver;
/* Engine handed to layer weight/bias initializers that have no engine of their own (see Model::allocate). */
77 engines::EnginePtr engine;
/**
 * Class representing the model of a neural network on the training stage.
 * Extends neural_networks::ModelImpl with backward layers, weight/bias derivatives
 * and per-solver optional arguments.
 */
85 class DAAL_EXPORT Model :
public neural_networks::ModelImpl
88 DECLARE_SERIALIZABLE_CAST(Model);
/*
 * Re-expose the base-class overloads: this class declares its own
 * getWeightsAndBiases(size_t) / setWeightsAndBiases(size_t, ...) overloads,
 * which would otherwise hide the inherited ones.
 */
90 using neural_networks::ModelImpl::getWeightsAndBiases;
91 using neural_networks::ModelImpl::setWeightsAndBiases;
/** Default constructor (marked deprecated); defined out of line. */
97 DAAL_DEPRECATED Model();
/**
 * Factory returning a shared pointer to a new Model (marked deprecated); defined out of line.
 * \param[out] stat  Optional pointer to a Status object; defaults to NULL.
 *                   NOTE(review): presumably receives the construction status when non-null - confirm in the implementation.
 */
102 DAAL_DEPRECATED
static services::SharedPtr<Model> create(services::Status *stat = NULL);
/**
 * Copy constructor (marked deprecated).
 * Initializes _backwardLayers from the pointer returned by model.getBackwardLayers()
 * and copies the _storeWeightDerivativesInTable flag.
 * NOTE(review): the base-class initializer and the constructor body are not visible in this listing.
 */
108 DAAL_DEPRECATED Model(
const Model &model) :
110 _backwardLayers(model.getBackwardLayers()),
111 _storeWeightDerivativesInTable(model._storeWeightDerivativesInTable)
126 template<
typename modelFPType>
127 services::Status initialize(
const services::Collection<size_t> &sampleSize,
const Topology &topology,
128 const Parameter ¶meter = Parameter())
130 using namespace layers;
131 using namespace services;
133 size_t nLayers = topology.size();
135 _backwardNextLayers = SharedPtr<Collection<NextLayers> >(
new Collection<NextLayers>(nLayers));
136 if (!_backwardNextLayers)
138 st.add(services::ErrorMemoryAllocationFailed);
142 for(
size_t i = 0; i < nLayers; i++)
144 insertLayer(topology[i]);
147 for(
int i = (
int)nLayers - 1; i >= 0; i--)
149 size_t layerId = topology[i].index();
150 const NextLayers &next = topology[i].nextLayers();
151 for (
size_t j = 0; j < next.size(); j++)
153 (*_backwardNextLayers)[next[j]].push_back(layerId);
157 for(
int i = (
int)nLayers - 1; i >= 0; i--)
159 layers::forward::LayerIfacePtr layer = getForwardLayer(i);
160 SharedPtr<split::forward::Batch<float> > splitLayerFloat = dynamicPointerCast<split::forward::Batch<float>, forward::LayerIface>(layer);
161 SharedPtr<split::forward::Batch<double> > splitLayerDouble = dynamicPointerCast<split::forward::Batch<double>, forward::LayerIface>(layer);
162 if(splitLayerFloat.get() || splitLayerDouble.get())
164 const NextLayers &next = topology[i].nextLayers();
165 for (
size_t j = 0; j < next.size(); j++)
167 layers::forward::LayerIfacePtr nextLayer = getForwardLayer(next[j]);
168 nextLayer->getLayerParameter()->allowInplaceComputation =
false;
173 allocate<modelFPType>(sampleSize, parameter);
175 for(
size_t i = 0; i < nLayers; i++)
177 getForwardLayer(i)->enableResetOnCompute(
false);
178 getBackwardLayer(i)->enableResetOnCompute(
false);
/**
 * Returns the collection of forward layers of the model (marked deprecated).
 * \return The _forwardLayers pointer
 */
188 DAAL_DEPRECATED
const ForwardLayersPtr getForwardLayers()
const
190 return _forwardLayers;
/**
 * Returns the forward layer stored at the given position.
 * \param[in] index  Index of the layer in _forwardLayers; no range check is performed here
 * \return Result of _forwardLayers->get(index)
 */
199 const layers::forward::LayerIfacePtr getForwardLayer(
size_t index)
const
201 return _forwardLayers->get(index);
/**
 * Returns the collection of backward layers of the model.
 * \return The _backwardLayers pointer
 */
209 const BackwardLayersPtr getBackwardLayers()
const
211 return _backwardLayers;
/**
 * Returns the backward layer stored at the given position.
 * \param[in] index  Index of the layer in _backwardLayers; no range check is performed here
 * \return Result of _backwardLayers->get(index)
 */
220 const layers::backward::LayerIfacePtr getBackwardLayer(
size_t index)
const
222 return _backwardLayers->get(index);
/**
 * Builds a prediction model from this training model.
 * For every forward layer the corresponding prediction layer is obtained with
 * getLayerForPrediction() and flagged with predictionStage = true; the next-layer
 * lists are copied as is. The weights and biases of this model are then passed to
 * the new prediction model.
 * \tparam modelFPType  Type of the zero value passed to the prediction::Model constructor
 * \return Shared pointer to the newly created prediction model
 */
230 template<
typename modelFPType>
231 const prediction::ModelPtr getPredictionModel()
233 using namespace services;
234 using namespace data_management;
235 using namespace layers;
237 size_t nLayers = _forwardLayers->size();
/* Both collections are created with nLayers slots and filled by index below. */
240 ForwardLayersPtr _predictionForwardLayers(
new ForwardLayers(nLayers));
241 SharedPtr<Collection<NextLayers> > _predictionNextLayers(
new Collection<NextLayers>(nLayers));
242 for (
size_t i = 0; i < nLayers; i++)
244 (*_predictionNextLayers)[i] = _nextLayers->get(i);
245 (*_predictionForwardLayers)[i] = ((*_forwardLayers)[i])->getLayerForPrediction();
246 (*_predictionForwardLayers)[i]->getLayerParameter()->predictionStage =
true;
/* The prediction model is always created with storeWeightsInTable == true. */
249 bool storeWeightsInTable =
true;
250 prediction::ModelPtr predictionModel(
new prediction::Model(
251 _predictionForwardLayers, _predictionNextLayers, (modelFPType)0.0, storeWeightsInTable));
/* Hand the current weights and biases over to the prediction model. */
253 predictionModel->setWeightsAndBiases(getWeightsAndBiases());
254 return predictionModel;
/**
 * Returns the _storeWeightsInTable flag (marked deprecated).
 * allocate() uses the same flag to choose between one solver optional-argument slot
 * for the whole model and one slot per layer.
 */
263 DAAL_DEPRECATED
bool getWeightsAndBiasesStorageStatus()
const
265 return _storeWeightsInTable;
/**
 * Sets the weights and biases table for the entry with the given index (marked deprecated);
 * defined out of line.
 * \param[in] idx    Index of the entry - NOTE(review): presumably a layer index, confirm in the implementation
 * \param[in] table  Numeric table holding the weights and biases
 * \return Status of the operation
 */
276 DAAL_DEPRECATED services::Status setWeightsAndBiases(
size_t idx,
const data_management::NumericTablePtr &table);
/**
 * Returns the weights and biases table for the entry with the given index (marked deprecated);
 * defined out of line.
 * \param[in] idx  Index of the entry - NOTE(review): presumably a layer index, confirm in the implementation
 */
284 DAAL_DEPRECATED data_management::NumericTablePtr getWeightsAndBiases(
size_t idx)
const;
/**
 * Returns the table of weight and bias derivatives (marked deprecated); defined out of line.
 * This overload takes no index.
 */
291 DAAL_DEPRECATED data_management::NumericTablePtr getWeightsAndBiasesDerivatives()
const;
/**
 * Returns the table of weight and bias derivatives for the entry with the given index
 * (marked deprecated); defined out of line.
 * \param[in] idx  Index of the entry - NOTE(review): presumably a layer index, confirm in the implementation
 */
299 DAAL_DEPRECATED data_management::NumericTablePtr getWeightsAndBiasesDerivatives(
size_t idx)
const;
/**
 * Sets the error collection of the model (marked deprecated).
 * \param[in] errors  Error collection
 * \return A default-constructed Status
 * NOTE(review): the statement between the signature and the return is missing from this
 * listing, so whether `errors` is stored in _errors cannot be confirmed from here.
 */
309 DAAL_DEPRECATED services::Status setErrors(services::ErrorCollection &errors)
311 return services::Status();
/**
 * Returns a const reference to the model's error collection (marked deprecated).
 * The reference points at the _errors member, so it is valid only while the model exists.
 */
320 DAAL_DEPRECATED
const services::ErrorCollection &getErrors()
const {
return _errors; }
330 template<
typename modelFPType>
331 services::Status allocate(
const services::Collection<size_t> &sampleSize,
const Parameter ¶meter = Parameter())
333 using namespace services;
334 using namespace data_management;
335 using namespace layers;
339 if (_sampleSize.size() > 0) { _sampleSize.clear(); }
340 _sampleSize = sampleSize;
342 _forwardLayers->get(0)->getLayerInput()->set(forward::data,
343 TensorPtr(
new HomogenTensor<modelFPType>(_sampleSize, Tensor::doAllocate)));
345 size_t nLayers = _forwardLayers->size();
347 for (
size_t i = 0; i < nLayers; i++)
349 layers::Parameter *lParameter = _forwardLayers->get(i)->getLayerParameter();
350 initializers::Parameter *wParameter = lParameter->weightsInitializer->getParameter();
351 initializers::Parameter *bParameter = lParameter->biasesInitializer->getParameter();
353 s |= connectForwardLayers(i);
355 if(!wParameter->engine)
357 wParameter->engine = parameter.engine;
359 if(!bParameter->engine)
361 bParameter->engine = parameter.engine;
365 bool checkWeightsAndBiasesAlloc =
true;
366 s |= createWeightsAndBiases<modelFPType>(checkWeightsAndBiasesAlloc);
367 s |= enableConditionalGradientPropagation();
370 for (
size_t i = 0; i < nLayers; i++)
372 forward::LayerIfacePtr forwardLayer = _forwardLayers->get(i);
373 forward::Input *forwardInput = forwardLayer->getLayerInput();
375 forwardLayer->getLayerResult()->setResultForBackward(forwardInput);
379 s |= checkWeightsAndBiasesDerivativesAllocation();
381 for (
int i = (
int)nLayers - 1; i >= 0; i--)
383 s |= connectBackwardLayers(i);
386 s |= createWeightsAndBiasesDerivatives<modelFPType>();
387 if(_solverOptionalArgumentCollection.size() == 0)
389 if(_storeWeightsInTable) _solverOptionalArgumentCollection = DataCollection(1);
390 else _solverOptionalArgumentCollection = DataCollection(nLayers);
/**
 * Constructor taking a Status by reference (marked deprecated); defined out of line.
 * NOTE(review): presumably reports construction errors through `st` - confirm in the implementation.
 */
400 DAAL_DEPRECATED Model(services::Status &st);
/**
 * Serialization hook.
 * \tparam Archive        Archive type
 * \tparam onDeserialize  true when called for deserialization
 * \param[in] arch        Archive; not used in the lines shown
 * \return A default-constructed Status
 * NOTE(review): part of the body is missing from this listing; the visible code performs
 * no serialization.
 */
406 template<
typename Archive,
bool onDeserialize>
407 services::Status serialImpl(Archive *arch)
409 return services::Status();
/**
 * Adds one layer of the topology to the model.
 * The forward and backward layers of the descriptor are cloned, so the model does not
 * share layer objects with the topology; the clones and the list of next layers are
 * inserted at the descriptor's index.
 * \param[in] layerDescriptor  Descriptor providing the index, the layer and its next layers
 */
415 void insertLayer(
const layers::LayerDescriptor &layerDescriptor)
417 _forwardLayers->insert(layerDescriptor.index(), layerDescriptor.layer()->forwardLayer->clone());
418 _backwardLayers->insert(layerDescriptor.index(), layerDescriptor.layer()->backwardLayer->clone());
419 _nextLayers->insert(layerDescriptor.index(), layerDescriptor.nextLayers());
425 services::Status enableConditionalGradientPropagation()
427 using namespace services;
428 using namespace layers;
432 size_t nLayers = _forwardLayers->size();
435 bool *flags = (
bool *)daal_malloc(nLayers *
sizeof(
bool));
439 s |= disableGradientPropagationInStartingLayers(nLayers, flags);
443 s |= enableGradientPropagation(nLayers, flags);
/**
 * Stack-based traversal of the layer graph that switches gradient propagation off.
 * Every layer popped from the stack for the first time gets propagateGradient = false on
 * its backward layer. The traversal continues into the next layers only when the layer
 * has neither weights nor biases, so it stops at the first layer with learnable parameters
 * on each path.
 * \param[in]     nLayers  Number of layers; length of `visited`
 * \param[in,out] visited  Caller-provided scratch array of per-layer flags
 * \return A default-constructed Status
 * NOTE(review): the body of the first loop and the statement that seeds the stack are
 * missing from this listing - confirm the initialization against the header.
 */
452 services::Status disableGradientPropagationInStartingLayers(
size_t nLayers,
bool *visited)
454 using namespace services;
455 using namespace layers;
457 for (
size_t i = 0; i < nLayers; i++)
462 Collection<size_t> stack;
464 while (stack.size() > 0)
/* Pop the last element of the stack. */
466 size_t layerId = stack[stack.size() - 1];
467 stack.erase(stack.size() - 1);
468 if (!visited[layerId])
470 visited[layerId] =
true;
472 forward::LayerIfacePtr forwardLayer = _forwardLayers->get(layerId);
473 forward::Input *forwardInput = forwardLayer->getLayerInput();
474 layers::Parameter *forwardParameter = forwardLayer->getLayerParameter();
475 layers::Parameter *backwardParameter = _backwardLayers->get(layerId)->getLayerParameter();
477 backwardParameter->propagateGradient =
false;
/* Descend further only through layers without weights and biases. */
479 if (forwardInput->getWeightsSizes(forwardParameter).size() +
480 forwardInput->getBiasesSizes(forwardParameter) .size() == 0)
483 const NextLayers &next = _nextLayers->get(layerId);
484 for (
size_t i = 0; i < next.size(); i++)
486 stack.push_back(next[i]);
491 return services::Status();
/**
 * Switches gradient propagation on in every layer reachable from the next layers of
 * `startLayerId`. The start layer itself is not modified.
 * \param[in]     startLayerId        Layer whose successors are processed
 * \param[in]     nLayers             Number of layers; not used in the lines shown
 * \param[in,out] enabledPropagation  Per-layer flags; a set flag means the layer was
 *                                    already processed and is skipped
 * \return A default-constructed Status
 */
497 services::Status enableGradientPropagationInSubsequentLayers(
size_t startLayerId,
size_t nLayers,
bool *enabledPropagation)
499 using namespace services;
500 using namespace layers;
501 Collection<size_t> stack;
/* Seed the stack with the direct successors of the start layer. */
502 const NextLayers &next = _nextLayers->get(startLayerId);
503 for (
size_t i = 0; i < next.size(); i++)
505 stack.push_back(next[i]);
507 while (stack.size() > 0)
/* Pop the last element of the stack. */
509 size_t layerId = stack[stack.size() - 1];
510 stack.erase(stack.size() - 1);
511 if (!enabledPropagation[layerId])
513 enabledPropagation[layerId] =
true;
514 backward::LayerIfacePtr backwardLayer = _backwardLayers->get(layerId);
515 backwardLayer->getLayerParameter()->propagateGradient =
true;
516 const NextLayers &next = _nextLayers->get(layerId);
517 for (
size_t i = 0; i < next.size(); i++)
519 stack.push_back(next[i]);
523 return services::Status();
/**
 * Stack-based traversal that finds layers which have weights or biases and whose backward
 * layer currently has propagateGradient == false, and enables gradient propagation in all
 * layers that follow them.
 * \param[in]     nLayers             Number of layers; length of `enabledPropagation`
 * \param[in,out] enabledPropagation  Scratch array of per-layer flags; reset to false here
 * \return A default-constructed Status
 * NOTE(review): the statement that seeds the stack is missing from this listing.
 * NOTE(review): the Status returned by enableGradientPropagationInSubsequentLayers() is
 * not captured below.
 */
529 services::Status enableGradientPropagation(
size_t nLayers,
bool *enabledPropagation)
531 using namespace services;
532 using namespace layers;
533 Collection<size_t> stack;
/* Clear all flags before the traversal. */
536 for (
size_t i = 0; i < nLayers; i++)
538 enabledPropagation[i] =
false;
541 while (stack.size() > 0)
/* Pop the last element of the stack. */
543 size_t layerId = stack[stack.size() - 1];
544 stack.erase(stack.size() - 1);
545 if (!enabledPropagation[layerId])
547 forward::LayerIfacePtr forwardLayer = _forwardLayers->get(layerId);
548 forward::Input *forwardInput = forwardLayer->getLayerInput();
549 layers::Parameter *forwardParameter = forwardLayer->getLayerParameter();
550 layers::Parameter *backwardParameter = _backwardLayers->get(layerId)->getLayerParameter();
/* A layer with learnable parameters and propagation switched off: everything
   downstream of it needs the gradient. */
552 if (backwardParameter->propagateGradient ==
false &&
553 (forwardInput->getWeightsSizes(forwardParameter).size() +
554 forwardInput->getBiasesSizes(forwardParameter) .size()) > 0)
556 enableGradientPropagationInSubsequentLayers(layerId, nLayers, enabledPropagation);
560 const NextLayers &next = _nextLayers->get(layerId);
561 for (
size_t i = 0; i < next.size(); i++)
563 stack.push_back(next[i]);
568 return services::Status();
/**
 * Determines the value of _storeWeightDerivativesInTable.
 * The flag starts as true and is cleared as soon as some backward layer's result already
 * holds a weightDerivatives or biasDerivatives entry. Empty backward-layer slots are skipped.
 * \return A default-constructed Status
 */
574 services::Status checkWeightsAndBiasesDerivativesAllocation()
576 using namespace services;
577 using namespace layers;
579 _storeWeightDerivativesInTable =
true;
580 size_t nLayers = _backwardLayers->size();
581 for (
size_t i = 0; i < nLayers; i++)
583 backward::LayerIfacePtr &backwardLayer = _backwardLayers->get(i);
/* Skip slots without a backward layer. */
584 if (!backwardLayer) {
continue; }
585 backward::ResultPtr backwardResult = backwardLayer->getLayerResult();
/* Derivative storage already present in the result. */
587 if (backwardResult->get(backward::weightDerivatives) || backwardResult->get(backward::biasDerivatives))
589 _storeWeightDerivativesInTable =
false;
593 return services::Status();
/**
 * Connects one backward layer to its forward counterpart and to the backward layers
 * that consume its result.
 * The forward result is set as input of the backward layer and the backward result is
 * allocated. If the layer propagates the gradient, its result is added as input to every
 * layer listed in _backwardNextLayers for this layer.
 * \param[in] layerId  Index of the layer to connect
 * \return A default-constructed Status on every path shown
 */
599 services::Status connectBackwardLayers(
size_t layerId)
601 using namespace services;
602 using namespace data_management;
603 using namespace layers;
605 forward::LayerIfacePtr &forwardLayer = _forwardLayers->get(layerId);
606 backward::LayerIfacePtr &backwardLayer = _backwardLayers->get(layerId);
/* Nothing to connect when either half of the layer is absent. */
608 if (!forwardLayer || !backwardLayer) {
return services::Status(); }
610 backward::Input *backwardInput = backwardLayer->getLayerInput();
611 forward::ResultPtr forwardResult = forwardLayer->getLayerResult();
613 backwardInput->setInputFromForward(forwardResult);
614 backwardLayer->allocateResult();
/* Without gradient propagation the result is not passed on to other layers. */
618 if (!backwardLayer->getLayerParameter()->propagateGradient) {
return services::Status(); }
620 backward::ResultPtr backwardResult = backwardLayer->getLayerResult();
622 const NextLayers &next = _backwardNextLayers->get(layerId);
623 const size_t nextLayersSize = next.size();
624 for(
size_t j = 0; j < nextLayersSize; j++)
/* Input indices are assigned in reverse order of the list.
   NOTE(review): the meaning of the trailing 0 passed to addInput() is not visible here. */
626 size_t inputIndex = nextLayersSize - j - 1;
627 _backwardLayers->get(next[j])->addInput(backwardResult, inputIndex, 0 );
629 return services::Status();
/**
 * Creates the storage for weight and bias derivatives; exported and defined out of line.
 * Called from allocate() after the backward layers have been connected.
 * \tparam modelFPType  Floating-point type of the derivatives
 * \return Status of the operation
 */
635 template<
typename modelFPType>
636 DAAL_EXPORT services::Status createWeightsAndBiasesDerivatives();
/**
 * Returns the solver optional argument stored at the given position (marked deprecated).
 * The stored element is cast from SerializationIface to OptionalArgument with
 * dynamicPointerCast.
 * \param[in] index  Position in _solverOptionalArgumentCollection; no range check is performed here
 */
645 DAAL_DEPRECATED algorithms::OptionalArgumentPtr getSolverOptionalArgument(
size_t index)
647 return services::dynamicPointerCast<algorithms::OptionalArgument, data_management::SerializationIface>(_solverOptionalArgumentCollection[index]);
/**
 * Stores a solver optional argument at the given position (marked deprecated).
 * \param[in] solverOptionalArgument  Argument to store
 * \param[in] index                   Position in _solverOptionalArgumentCollection; no range check is performed here
 * \return A default-constructed Status
 */
658 DAAL_DEPRECATED services::Status setSolverOptionalArgument(
const algorithms::OptionalArgumentPtr& solverOptionalArgument,
size_t index)
660 _solverOptionalArgumentCollection[index] = solverOptionalArgument;
661 return services::Status();
/**
 * Returns the collection of solver optional arguments (marked deprecated).
 * The collection object is returned by value.
 */
669 DAAL_DEPRECATED data_management::DataCollection getSolverOptionalArgumentCollection()
671 return _solverOptionalArgumentCollection;
/**
 * Replaces the collection of solver optional arguments (marked deprecated).
 * \param[in] solverOptionalArgumentCollection  New collection; assigned to the member
 * \return A default-constructed Status
 */
681 DAAL_DEPRECATED services::Status setSolverOptionalArgumentCollection(
const data_management::DataCollection &solverOptionalArgumentCollection)
683 _solverOptionalArgumentCollection = solverOptionalArgumentCollection;
684 return services::Status();
/* Optional arguments of the optimization solver; sized in allocate() to 1 or to the number of layers. */
688 data_management::DataCollection _solverOptionalArgumentCollection;
/* Dimensions of the input data tensor, as passed to allocate(). */
689 services::Collection<size_t> _sampleSize;
/* Backward layers of the network, indexed like the forward layers. */
690 BackwardLayersPtr _backwardLayers;
/* For every layer, the ids of the layers that list it among their next layers; built in initialize(). */
691 services::SharedPtr<services::Collection<layers::NextLayers> > _backwardNextLayers;
/* Error collection returned by getErrors(); mutable, so const members may modify it. */
692 mutable services::ErrorCollection _errors;
/* Set by checkWeightsAndBiasesDerivativesAllocation(): false when a backward result already holds derivatives. */
694 bool _storeWeightDerivativesInTable;
/* NOTE(review): not referenced in the code shown; presumably the container managed by createWeightsAndBiasesDerivatives(). */
695 LearnableParametersIfacePtr _weightsAndBiasesDerivatives;
/* Shared-pointer alias for the training model. */
698 typedef services::SharedPtr<Model> ModelPtr;
/* Make the interface1 names available without the interface1 qualifier. */
702 using interface1::Parameter;
703 using interface1::Model;
704 using interface1::ModelPtr;
daal::algorithms::neural_networks::training::interface1::Model::getForwardLayer
const layers::forward::LayerIfacePtr getForwardLayer(size_t index) const
Definition: neural_networks_training_model.h:199
daal::algorithms::neural_networks::training::interface1::Model::getSolverOptionalArgumentCollection
DAAL_DEPRECATED data_management::DataCollection getSolverOptionalArgumentCollection()
Definition: neural_networks_training_model.h:669
daal::algorithms::interface1::Model
The base class for the classes that represent the models, such as linear_regression::Model or svm::Model.
Definition: model.h:54
daal::algorithms::neural_networks::training::interface1::Model::getWeightsAndBiasesStorageStatus
DAAL_DEPRECATED bool getWeightsAndBiasesStorageStatus() const
Definition: neural_networks_training_model.h:263
daal::algorithms::neural_networks::training::interface1::Model::~Model
virtual ~Model()
Destructor.
Definition: neural_networks_training_model.h:115
daal::algorithms::neural_networks::training::interface1::Parameter::engine
engines::EnginePtr engine
Definition: neural_networks_training_model.h:77
daal::algorithms::neural_networks::training::interface1::Model::setSolverOptionalArgument
DAAL_DEPRECATED services::Status setSolverOptionalArgument(const algorithms::OptionalArgumentPtr &solverOptionalArgument, size_t index)
Definition: neural_networks_training_model.h:658
daal::algorithms::neural_networks::training::interface1::Topology::push_back
size_t push_back(const layers::LayerIfacePtr &layer)
Definition: neural_networks_training_topology.h:78
daal::algorithms::neural_networks::training::interface1::Model::initialize
services::Status initialize(const services::Collection< size_t > &sampleSize, const Topology &topology, const Parameter &parameter=Parameter())
Definition: neural_networks_training_model.h:127
daal::algorithms::association_rules::data
Definition: apriori_types.h:83
daal::algorithms::neural_networks::training::interface1::Model::allocate
services::Status allocate(const services::Collection< size_t > &sampleSize, const Parameter &parameter=Parameter())
Definition: neural_networks_training_model.h:331
daal::algorithms::neural_networks::training::interface1::Model::Model
DAAL_DEPRECATED Model(const Model &model)
Copy constructor.
Definition: neural_networks_training_model.h:108
daal::algorithms::neural_networks::training::interface1::Model::getForwardLayers
DAAL_DEPRECATED const ForwardLayersPtr getForwardLayers() const
Definition: neural_networks_training_model.h:188
daal::algorithms::neural_networks::training::interface1::Model::getBackwardLayers
const BackwardLayersPtr getBackwardLayers() const
Definition: neural_networks_training_model.h:209
daal::algorithms::neural_networks::training::interface1::Topology::size
size_t size() const
Definition: neural_networks_training_topology.h:70
daal::algorithms::neural_networks::training::interface1::Model::getPredictionModel
const prediction::ModelPtr getPredictionModel()
Definition: neural_networks_training_model.h:231
daal::algorithms::neural_networks::training::interface1::Parameter::Parameter
Parameter(const services::SharedPtr< optimization_solver::iterative_solver::Batch > &optimizationSolver_=services::SharedPtr< optimization_solver::iterative_solver::Batch >(), engines::EnginePtr engine_=engines::mt19937::Batch< DAAL_ALGORITHM_FP_TYPE >::create())
Definition: neural_networks_training_model.h:71
daal::algorithms::neural_networks::training::interface1::Model::getErrors
DAAL_DEPRECATED const services::ErrorCollection & getErrors() const
Definition: neural_networks_training_model.h:320
daal::algorithms::neural_networks::training::interface1::Model::getSolverOptionalArgument
DAAL_DEPRECATED algorithms::OptionalArgumentPtr getSolverOptionalArgument(size_t index)
Definition: neural_networks_training_model.h:645
daal::algorithms::neural_networks::training::interface1::Topology
Class defining a neural network topology - a set of layers and connections between them - on the training stage.
Definition: neural_networks_training_topology.h:43
daal::algorithms::neural_networks::training::interface1::Model::getBackwardLayer
const layers::backward::LayerIfacePtr getBackwardLayer(size_t index) const
Definition: neural_networks_training_model.h:220
daal::algorithms::interface1::Parameter
Base class to represent computation parameters. Algorithm-specific parameters are represented as deri...
Definition: algorithm_types.h:62
daal::algorithms::neural_networks::training::interface1::Parameter::optimizationSolver
services::SharedPtr< optimization_solver::iterative_solver::Batch > optimizationSolver
Definition: neural_networks_training_model.h:76
daal::services::daal_malloc
DAAL_EXPORT void * daal_malloc(size_t size, size_t alignment=DAAL_MALLOC_DEFAULT_ALIGNMENT)
daal::algorithms::neural_networks::training::interface1::Model
Class representing the model of neural network.
Definition: neural_networks_training_model.h:85
daal::services::daal_free
DAAL_EXPORT void daal_free(void *ptr)
daal::services::ErrorMemoryAllocationFailed
Definition: error_indexes.h:150
daal::algorithms::neural_networks::training::interface1::Model::setErrors
DAAL_DEPRECATED services::Status setErrors(services::ErrorCollection &errors)
Definition: neural_networks_training_model.h:309
daal::algorithms::neural_networks::training::interface1::Model::setSolverOptionalArgumentCollection
DAAL_DEPRECATED services::Status setSolverOptionalArgumentCollection(const data_management::DataCollection &solverOptionalArgumentCollection)
Definition: neural_networks_training_model.h:681
daal::algorithms::neural_networks::training::model
Definition: neural_networks_training_result.h:54
daal::algorithms::neural_networks::training::interface1::Parameter
Class representing the parameters of neural network.
Definition: neural_networks_training_model.h:62