Régression logistique à l'aide de pytorch

Régression logistique à l'aide de pytorch
La régression logistique est un algorithme d'apprentissage automatique bien connu qui est utilisé pour résoudre des problèmes de classification binaire. Il est dérivé de l'algorithme de régression linéaire, qui a une variable de sortie continue, et la régression logistique peut même classer plus de deux classes en la modifiant légèrement. Nous examinerons le concept de régression logistique et comment il est mis en œuvre dans Pytorch, une bibliothèque utile pour créer des modèles d'apprentissage automatique et d'apprentissage en profondeur.

Concept de régression logistique

La régression logistique est un algorithme de classification binaire. Il s'agit d'un algorithme de prise de décision, ce qui signifie qu'il crée des limites entre deux classes. Il étend le problème de régression linéaire qui utilise un fonction d'activation sur ses sorties pour la limiter entre 1 et 0. En conséquence, cela est utilisé pour les problèmes de classification binaire. Le graphique de la régression logistique ressemble à la figure ci-dessous:

Nous pouvons voir que le graphique est limité entre 0 et 1. La régression linéaire normale peut donner la valeur cible comme n'importe quel nombre réel, mais ce n'est pas le cas avec la régression logistique due à la fonction sigmoïde. La régression logistique est basée sur le concept d'estimation du maximum de vraisemblance (MLE). Le maximum de vraisemblance est simplement de prendre une distribution de probabilité avec un ensemble donné de paramètres et de demander: «Quelle est la probabilité que je verrais ces données si mes données étaient générées à partir de cette distribution de probabilité?"Cela fonctionne en calculant la probabilité pour chaque point de données individuel, puis en multipliant toutes ces probabilités ensemble. Dans la pratique, nous ajoutons les logarithmes des probabilités.

Si nous devons construire un modèle d'apprentissage automatique, chaque point de données variable indépendant sera x1 * W1 + X2 * W2… et ainsi de suite, ce qui donne une valeur entre 0 et 1 lorsqu'il est passé par la fonction d'activation. Si nous prenons 0.50 en tant que facteur ou seuil décisif. Ensuite, tout résultat supérieur à 0.5 est considéré comme un 1, tandis que tout résultat inférieur à cela est considéré comme un 0.

Pour plus de 2 classes, nous utilisons l'approche One-VS-All. One-VS-All, également connu sous le nom de One-VS-Rest, est un processus de classification ML Multilabel et Multiclass ML. Il fonctionne en entraînant d'abord un classificateur binaire pour chaque catégorie, puis en montrant chaque classificateur à chaque entrée pour déterminer la classe à laquelle l'entrée appartient. Si votre problème a N classes, One-VS-All convertira votre ensemble de données de formation en problèmes de classification binaire.

La fonction de perte associée à la régression logistique est Entropie croisée binaire qui est l'inverse du gain d'information. Ceci est également connu comme le nom perte. La fonction de perte est donnée par l'équation:

Quelle est la fonction de perte?

Une fonction de perte est une métrique mathématique que nous voulons réduire. Nous voulons construire un modèle qui peut prédire avec précision ce que nous voulons, et une façon de mesurer les performances du modèle est d'examiner la perte car nous savons ce que le modèle produit et ce que nous devrions obtenir. Nous pouvons nous entraîner et améliorer notre modèle en utilisant cette perte et en ajustant les paramètres du modèle en conséquence. Les fonctions de perte varient en fonction du type d'algorithme. Pour la régression linéaire, l'erreur quadratique moyenne et l'erreur absolue moyenne sont des fonctions de perte populaires, tandis que la transformation croisée est appropriée pour les problèmes de classification.

Quelle est la fonction d'activation?

Les fonctions d'activation sont simplement des fonctions mathématiques qui modifient la variable d'entrée pour donner une nouvelle sortie. Cela se fait généralement dans l'apprentissage automatique pour standardiser les données ou restreindre l'entrée à une certaine limite. Les fonctions d'action populaires sont sigmoïdes, unité linéaire rectifiée (relu), tan (h), etc.

Qu'est-ce que Pytorch?

Pytorch est une alternative populaire en profondeur qui fonctionne avec Torch. Il a été créé par le département d'IA de Facebook, mais il peut être utilisé de la même manière que d'autres options. Il est utilisé pour développer une variété de modèles, mais il est le plus largement appliqué dans les cas d'utilisation du traitement du langage naturel (NLP). Pytorch est toujours une excellente option si vous souhaitez créer des modèles avec très peu de ressources et que vous souhaitez une bibliothèque conviviale, facile à utiliser et légère pour vos modèles. Il semble également naturel, ce qui aide à l'achèvement du processus. Nous utiliserons Pytorch pour la mise en œuvre de nos modèles pour les raisons mentionnées. Cependant, l'algorithme reste le même avec d'autres alternatives comme TensorFlow.

Implémentation de régression logistique dans Pytorch

Nous utiliserons les étapes ci-dessous pour implémenter notre modèle:

  1. Créez un réseau neuronal avec certains paramètres qui seront mis à jour après chaque itération.
  2. Itérer à travers les données d'entrée données.
  3. L'entrée passera à travers le réseau en utilisant la propagation vers l'avant.
  4. Nous calculons maintenant la perte en utilisant une entropie croisée binaire.
  5. Pour minimiser la fonction de coût, nous mettons à jour les paramètres à l'aide de la descente de gradient.
  6. Faites à nouveau les mêmes étapes en utilisant les paramètres mis à jour.

Nous classerons le Ensemble de données MNIST chiffres. Ceci est un problème populaire d'apprentissage en profondeur enseigné aux débutants.

Importons d'abord les bibliothèques et modules requis.

Importer une torche
de la torche.Variable d'importation Autograd
Importer TorchVision.se transforme en transformations
Importer TorchVision.ensembles de données comme DSET

L'étape suivante consiste à importer l'ensemble de données.

Train = DSET.Mnist (root = './ data ', train = true, transform = transforms.Totensor (), télécharger = false)
test = DSET.Mnist (root = './ data ', train = false, transform = transforts.Totensor ())

Utilisez le chargeur de données pour rendre vos données itératives

train_loader = torche.utils.données.DatalOader (DataSet = Train, Batch_Size = Batch_Size, Shuffle = True)
test_loader = torche.utils.données.DatalOader (DataSet = Test, Batch_Size = Batch_Size, Shuffle = False)

Définir le modèle.

Modèle de classe (torche.nn.Module):
def __init __ (self, inp, out):
Super (modèle, soi).__init __ ()
soi.linéaire = torche.nn.Linéaire (inp, out)
Def en avant (self, x):
sorties = soi.linéaire (x)
Sorties de retour

Spécifiez les hyperparamètres, l'optimiseur et la perte.

lot = 50
n_iters = 1500
époques = n_iters / (len (train_dataset) / lot)
INP = 784
Out = 10
alpha = 0.001
modèle = logistique (INP, Out)
Perte = torche.nn.Crossentropyloss ()
Optimizer = torche.optimisation.SGD (modèle.paramètres (), lr = alpha)

Entraîner enfin le modèle.

ITR = 0
pour l'époque dans la gamme (int (époques)):
pour i, (images, étiquettes) en énumération (Train_loader):
images = variable (images.View (-1, 28 * 28))
Labels = variable (étiquettes)
optimiseur.zero_grad ()
sorties = modèle (images)
Lossfunc = perte (sorties, étiquettes)
Lossfunc.en arrière ()
optimiseur.marcher()
itr + = 1
Si ITR% 500 == 0:
Correct = 0
Total = 0
Pour les images, étiquettes dans test_loader:
images = variable (images.View (-1, 28 * 28))
sorties = modèle (images)
_, prédit = torche.Max (sorties.données, 1)
Total + = étiquettes.taille (0)
Correct + = (étiquettes prévues ==).somme()
précision = 100 * correct / total
print ("itération est . La perte est . La précision est .".format (ITR, pertefunc.item (), précision))

Conclusion

Nous avons subi l'explication de la régression logistique et de sa mise en œuvre à l'aide de Pytorch, qui est une bibliothèque populaire pour développer des modèles d'apprentissage en profondeur. Nous avons implémenté le problème de classification de l'ensemble de données MNIST où nous avons reconnu les chiffres en fonction des paramètres d'images.