Pytorch - Argmax

Pytorch - Argmax
«Dans ce tutoriel Pytorch, nous verrons comment renvoyer les positions d'index des valeurs maximales d'un tenseur à l'aide d'argmax ().

Pytorch est un cadre open source disponible avec un langage de programmation Python. Nous pouvons traiter les données en pytorch sous la forme d'un tenseur.

Un tenseur est un tableau multidimensionnel utilisé pour stocker les données. Donc, pour utiliser un tenseur, nous devons importer le module de torche.

Pour créer un tenseur, la méthode utilisée est tenseur () »

Syntaxe:

torche.tenseur (données)

Où les données sont un tableau multidimensionnel.

Argmax ()

Argmax () dans Pytorch est utilisé pour renvoyer l'index de la valeur maximale de tous les éléments du tenseur d'entrée.

Syntaxe:

torche.argmax (tenseur, dim, keepdim)

  1. Le tenseur est le tenseur d'entrée
  2. DIM est de réduire la dimension. DIM = 0 spécifie la comparaison de la colonne, qui obtiendra l'index pour une valeur maximale le long d'une colonne, et DIM = 1 spécifie la comparaison des lignes, qui obtiendra l'index pour une valeur maximale le long de la ligne.
  3. KeepDim vérifie si le tenseur de sortie a une dimension (DIM) conservée ou non

Exemple 1

Dans cet exemple, nous créerons un tenseur avec 2 dimensions qui ont 3 lignes et 5 colonnes et appliquer Argmax () sur les lignes et les colonnes.

Module de torche #mport
Importer une torche
#create un tenseur avec 2 dimensions (3 * 5)
# avec des éléments aléatoires en utilisant la fonction randn ()
données = torche.Randn (3,5)
#afficher
Imprimer (données)
#get index maximum le long des colonnes avec argmax
imprimer (torche.argmax (données, dim = 0))
#get index maximum le long des lignes avec argmax
imprimer (torche.argmax (données, dim = 1))

Sortir:

tenseur ([[0.6699, 1.3390, -1.0658, -1.8200, 0.6544],
[-0.3117, 0.2488, 0.2677, 0.2568, 0.5337],
[-1.0966, 1.8024, -0.7538, -0.2553, -1.0591]])
tenseur ([0, 2, 1, 1, 0])
tenseur ([1, 4, 1])

Nous pouvons voir que les valeurs maximales présentes dans l'index le long des colonnes sont:

  1. Valeur maximale - 0.6699. Son index est 0.
  2. Valeur maximale - 1.8024. Son index est 2.
  3. Valeur maximale - 0.2677. Son index est 1.
  4. Valeur maximale - 0.2568. Son index est 1.
  5. Valeur maximale - 0.6544. Son index est 0.

De même, les valeurs maximales présentes à l'index le long des lignes sont:

  1. Valeur maximale - 1.3390. Son index est 1.
  2. Valeur maximale - 0.5337. Son index est 4.
  3. Valeur maximale - 1.8024. Son index est 1.

Exemple 2

Créez du tenseur avec 5 * 5 matrice et appliquez Argmax ()

Module de torche #mport
Importer une torche
#create un tenseur avec 2 dimensions (5 * 5)
# avec des éléments aléatoires en utilisant la fonction randn ()
données = torche.Randn (5,5)
#afficher
Imprimer (données)
#get index maximum le long des colonnes avec argmax
imprimer (torche.argmax (données, dim = 0))
#get index maximum le long des lignes avec argmax
imprimer (torche.argmax (données, dim = 1))

Sortir:

Tensor ([[- 0.9553, -0.2611, -2.1233, -0.5208, -0.3458],
[-0.5466, -1.6395, 0.2576, -0.3123, 0.6785],
[-0.4574, 1.5301, 0.4812, 0.3434, 0.1388],
[0.8364, 0.3821, 0.1529, 1.4529, 0.3747],
[-1.4991, -1.8821, -0.2861, -0.4067, 1.1323]])
tenseur ([3, 2, 2, 3, 4])
tenseur ([1, 4, 1, 3, 4])

Nous pouvons voir que les valeurs maximales présentes dans l'index le long des colonnes sont:

  1. Valeur maximale - 0.8364. Son index est 3.
  2. Valeur maximale - 1.5301. Son index est 2.
  3. Valeur maximale - 0.4812. Son index est 2.
  4. Valeur maximale - 1.4529. Son index est 3.
  5. Valeur maximale - 1.1323. Son index est 4.

De même, les valeurs maximales présentes à l'index le long des lignes sont:

  1. Valeur maximale - -0.2611. Son index est 1.
  2. Valeur maximale - 0.6785. Son index est 4.
  3. Valeur maximale - 1.5301. Son index est 1.
  4. Valeur maximale - 1.4529. Son index est 3.
  5. Valeur maximale - 1.1323. Son index est 4.

Travailler avec le processeur

Si vous souhaitez exécuter une fonction argmax () sur le CPU, alors nous devons créer un tenseur avec une fonction CPU (). Cela fonctionnera sur une machine CPU.

Lorsque nous créons un tenseur, pour le moment, nous pouvons utiliser la fonction CPU ().

Syntaxe:

torche.tenseur (données).CPU()

Exemple 1

Créez du tenseur avec 5 * 5 matrice avec CPU () et appliquez Argmax ()
Module de torche #mport
Importer une torche
#create un tenseur avec 2 dimensions (5 * 5)
# avec des éléments aléatoires en utilisant la fonction randn () avec CPU ()
données = torche.Randn (5,5).CPU()
#afficher
Imprimer (données)
#get index maximum le long des colonnes avec argmax
imprimer (torche.argmax (données, dim = 0))
#get index maximum le long des lignes avec argmax
imprimer (torche.argmax (données, dim = 1))

Sortir:

Tensor ([[- 0.2213, 1.6140, -0.0774, 0.4135, 0.1379],
[-0.4415, -2.5789, 0.8294, -0.9309, 1.3535],
[-1.3256, -0.7233, -0.9713, 1.0742, 1.9350],
[-0.7126, -1.3336, 0.7371, -0.2253, 0.1675],
[-0.1174, -0.5773, 0.8887, -0.2563, -1.0667]])
tenseur ([4, 0, 4, 2, 2])
tenseur ([1, 4, 4, 2, 2])

Nous pouvons voir que les valeurs maximales présentes dans l'index le long des colonnes sont:

  1. Valeur maximale - -0.1174. Son index est 4.
  2. Valeur maximale - 1.6140. Son index est 0.
  3. Valeur maximale - 0.8887. Son index est 4.
  4. Valeur maximale - 1.0742. Son index est 2.
  5. Valeur maximale - 1.9350. Son index est 2.

De même, les valeurs maximales présentes à l'index le long des lignes sont:

  1. Valeur maximale - 1.6140. Son index est 1.
  2. Valeur maximale - 1.3535. Son index est 4.
  3. Valeur maximale - 1.9350. Son index est 4.
  4. Valeur maximale - 0.7371. Son index est 2.
  5. Valeur maximale - 0.8887. Son index est 2.

Exemple 2

Dans cet exemple, nous créerons un tenseur avec 2 dimensions qui ont 3 lignes et 5 colonnes à l'aide de la fonction CPU () et appliquer Argmax () sur les lignes et les colonnes.

Module de torche #mport
Importer une torche
#create un tenseur avec 2 dimensions (3 * 5)
# avec des éléments aléatoires en utilisant randn () avec CPU ()
données = torche.Randn (3,5).CPU()
#afficher
Imprimer (données)
#get index maximum le long des colonnes avec argmax
imprimer (torche.argmax (données, dim = 0))
#get index maximum le long des lignes avec argmax
imprimer (torche.argmax (données, dim = 1))

Sortir:

tenseur ([[0.6699, 1.3390, -1.0658, -1.8200, 0.6544],
[-0.3117, 0.2488, 0.2677, 0.2568, 0.5337],
[-1.0966, 1.8024, -0.7538, -0.2553, -1.0591]])
tenseur ([0, 2, 1, 1, 0])
tenseur ([1, 4, 1])

Nous pouvons voir que les valeurs maximales présentes dans l'index le long des colonnes sont:

  1. Valeur maximale - 0.6699. Son index est 0.
  2. Valeur maximale - 1.8024. Son index est 2.
  3. Valeur maximale - 0.2677. Son index est 1.
  4. Valeur maximale - 0.2568. Son index est 1.
  5. Valeur maximale - 0.6544. Son index est 0.

De même, les valeurs maximales présentes à l'index le long des lignes sont:

  1. Valeur maximale - 1.3390. Son index est 1.
  2. Valeur maximale - 0.5337. Son index est 4.
  3. Valeur maximale - 1.8024. Son index est 1.

Conclusion

Dans cette leçon de pytorch, nous avons vu ce que Argmax () et comment appliquer Argmax () sur un tenseur pour retourner des indices de valeurs maximales entre les colonnes et les lignes.

Nous avons également créé un tenseur avec une fonction CPU () et retourné des indices de valeurs maximales. DIM est le paramètre utilisé pour retourner les indices de valeurs maximales entre les colonnes lorsqu'elle est définie sur 0 et renvoie des indices de valeurs maximales entre les lignes lorsqu'elle est définie sur 1.