17  K-moyennes

Dans ce chapitre, nous allons explorer plusieurs visualisations de l’agrégation (clustering) par K-moyennes, un algorithme d’apprentissage non supervisé.

Plan du chapitre :

17.1 Visualisation des données de l’iris avec des étiquettes

Nous commençons par une visualisation typique des données de l’iris, comprenant une légende de couleurs pour indiquer les Espèces.

library(animint2)
color.code <- c(
  setosa="#1B9E77",
  versicolor="#D95F02",
  virginica="#7570B3",
  "1"="#E7298A",
  "2"="#66A61E",
  "3"="#E6AB02", 
  "4"="#A6761D")
ggplot()+
  scale_color_manual(values=color.code)+
  geom_point(aes(
    Petal.Length, Petal.Width, color=Species),
    data=iris)+
  coord_equal()

Nous allons illustrer l’algorithme d’agrégation K-moyennes à l’aide de ces deux dimensions.

data.mat <- as.matrix(iris[,c("Petal.Width","Petal.Length")])
head(data.mat)
     Petal.Width Petal.Length
[1,]         0.2          1.4
[2,]         0.2          1.4
[3,]         0.2          1.3
[4,]         0.2          1.5
[5,]         0.2          1.4
[6,]         0.4          1.7
str(data.mat)
 num [1:150, 1:2] 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
 - attr(*, "dimnames")=List of 2
  ..$ : NULL
  ..$ : chr [1:2] "Petal.Width" "Petal.Length"

Pour exécuter K-moyennes, l’hyperparamètre du nombre d’agrégations (K) doit être fixé à l’avance. Ensuite, les K-points de données aléatoires sont sélectionnés comme centres d’agrégation initiaux,

K <- 3
library(data.table)
data.dt <- data.table(data.mat)
set.seed(3)
centers.dt <- data.dt[sample(1:.N, K)]
(centers.mat <- as.matrix(centers.dt))
     Petal.Width Petal.Length
[1,]         0.2          1.4
[2,]         2.1          5.4
[3,]         0.2          1.2
centers.dt[, cluster := factor(1:K)]
centers.dt
   Petal.Width Petal.Length cluster
1:         0.2          1.4       1
2:         2.1          5.4       2
3:         0.2          1.2       3
gg.centers <- ggplot()+
  scale_color_manual(values=color.code)+
  geom_point(aes(
    Petal.Length, Petal.Width),
    color="grey50",
    data=data.dt)+
  geom_point(aes(
    Petal.Length, Petal.Width, color=cluster),
    data=centers.dt)+
  coord_equal()
gg.centers

Ci-dessus, nous avons affiché les deux ensembles de données (centres d’agrégation et données) à l’aide de deux instances de geom_point(). Nous calculons ci-dessous la distance entre chaque point de données et chaque centre d’agrégation,

pairs.dt <- data.table(expand.grid(
  centers.i=1:nrow(centers.mat),
  data.i=1:nrow(data.mat)))

Ces éléments peuvent être visualisés à l’aide d’un geom_point(),

seg.dt <- pairs.dt[, data.table(
  data.i,
  data=data.mat[data.i,],
  center=centers.mat[centers.i,])]
gg.centers+
  geom_segment(aes(
    data.Petal.Length, data.Petal.Width,
    xend=center.Petal.Length, yend=center.Petal.Width),
    size=1,
    data=seg.dt)

Il y a 450 segments superposés ci-dessus, de sorte que l’interactivité serait utile pour mettre en évidence les segments connectés à un point de données précis. Pour ce faire, nous créons une variable de sélection data.i,

animint(
  ggplot()+
    theme_bw()+
    theme_animint(height=300, width=640)+
    scale_color_manual(values=color.code)+
    scale_x_continuous(breaks=seq(1,7,by=0.5))+
    scale_y_continuous(breaks=seq(0, 2.5, by=0.5))+
    geom_point(aes(
      Petal.Length, Petal.Width, color=cluster),
      size=4,
      data=centers.dt)+
    geom_segment(aes(
      data.Petal.Length, data.Petal.Width,
      xend=center.Petal.Length, yend=center.Petal.Width),
      size=1,
      showSelected="data.i",
      data=seg.dt)+
    geom_point(aes(
      Petal.Length, Petal.Width),
      clickSelects="data.i",
      size=2,
      color="grey50",
      data=data.table(data.mat, data.i=1:nrow(data.mat))))

Dans la visualisation des données ci-dessus, vous pouvez cliquer sur un point de données pour afficher les distances entre ce point et chaque centre d’agrégation.

Exercices pour cette section :

  • Modifier les échelles x/y de façon à ce que les mêmes tics soient représentés.
  • Modifier la couleur de chaque segment pour qu’elle soit la même que l’agrégation correspondante.
  • Ajouter une infobulle qui indique la valeur de la distance.
  • Faire varier la largeur du segment en fonction de son optimalité (le segment connecté au centre de l’agrégation la plus proche devrait être mis en évidence par une largeur plus élevée).

17.2 Visualisation des itérations de l’algorithme

Nous calculons le centre d’agrégation le plus proche pour chaque point de données,

pairs.dt[, error := rowSums(
(data.mat[data.i,]-centers.mat[centers.i,])^2)]
(closest.dt <- pairs.dt[, .SD[which.min(error)], by=data.i])
     data.i centers.i error
  1:      1         1  0.00
  2:      2         1  0.00
 ---                       
149:    149         2  0.04
150:    150         2  0.18
(closest.data <- closest.dt[, .(
  data.dt[data.i],
  cluster=factor(centers.i)
)])
     Petal.Width Petal.Length cluster
  1:         0.2          1.4       1
  2:         0.2          1.4       1
 ---                                 
149:         2.3          5.4       2
150:         1.8          5.1       2
(both.dt <- rbind(
  data.table(type="centers", centers.dt),
  data.table(type="data", closest.data)))
        type Petal.Width Petal.Length cluster
  1: centers         0.2          1.4       1
  2: centers         2.1          5.4       2
 ---                                         
152:    data         2.3          5.4       2
153:    data         1.8          5.1       2
ggplot()+
  scale_fill_manual(values=color.code)+
  scale_color_manual(values=c(centers="black", data="grey"))+
  scale_size_manual(values=c(centers=5, data=3))+
  geom_point(aes(
    Petal.Length, Petal.Width, fill=cluster, size=type, color=type),
    data=both.dt)+
  coord_equal()+
  theme_bw()

Ensuite, nous mettons à jour les centres d’agrégation,

new.centers <- closest.dt[, data.table(
  t(colMeans(data.dt[data.i]))
), by=.(cluster=centers.i)]
(new.both <- rbind(
  data.table(type="centers", new.centers),
  data.table(type="data", closest.data)))
        type cluster Petal.Width Petal.Length
  1: centers       1       0.300     1.595918
  2: centers       3       0.175     1.125000
 ---                                         
152:    data       2       2.300     5.400000
153:    data       2       1.800     5.100000
ggplot()+
  scale_fill_manual(values=color.code)+
  scale_color_manual(values=c(centers="black", data="grey"))+
  scale_size_manual(values=c(centers=5, data=3))+
  geom_point(aes(
    Petal.Length, Petal.Width, fill=cluster, size=type, color=type),
    data=new.both)+
  coord_equal()+
  theme_bw()

Les visualisations ci-dessus montrent donc les étapes de la méthode K-moyennes : (1) mise à jour de l’agrégation sur la base du centre le plus proche (2) mise à jour du centre sur la base des données affectées à cette agrégation. Pour visualiser plusieurs itérations des deux étapes ci-dessus, nous pouvons utiliser une boucle for,

set.seed(3)
centers.dt <- data.dt[sample(1:.N, K)]
(centers.mat <- as.matrix(centers.dt))
     Petal.Width Petal.Length
[1,]         0.2          1.4
[2,]         2.1          5.4
[3,]         0.2          1.2
data.and.centers.list <- list()
iteration.error.list <- list()
for(iteration in 1:20){
  pairs.dt[, error := {
    rowSums((data.mat[data.i,]-centers.mat[centers.i,])^2)
  }]
  closest.dt <- pairs.dt[, .SD[which.min(error)], by=data.i]
  iteration.error.list[[iteration]] <- data.table(
    iteration, error=sum(closest.dt[["error"]]))
  iteration.both <- rbind(
    data.table(type="centers", centers.dt, cluster=1:K),
    closest.dt[, data.table(
      type="data", data.dt[data.i], cluster=factor(centers.i))])
  data.and.centers.list[[iteration]] <- data.table(
    iteration, iteration.both)
  new.centers <- closest.dt[, data.table(
    t(colMeans(data.dt[data.i]))
  ), keyby=.(cluster=centers.i)]
  centers.dt <- new.centers[, names(centers.dt), with=FALSE]
  centers.mat <- as.matrix(centers.dt)
}
(data.and.centers <- do.call(rbind, data.and.centers.list))
      iteration    type Petal.Width Petal.Length cluster
   1:         1 centers         0.2          1.4       1
   2:         1 centers         2.1          5.4       2
  ---                                                   
3059:        20    data         2.3          5.4       2
3060:        20    data         1.8          5.1       2
(iteration.error <- do.call(rbind, iteration.error.list))
    iteration     error
 1:         1 123.63000
 2:         2  85.82705
---                    
19:        19  31.37136
20:        20  31.37136

Nous commençons par créer un graphique d’ensemble avec une courbe d’erreur qui servira à sélectionner la taille du modèle,

gg.err <- ggplot()+
  theme_bw()+
  geom_point(aes(
    iteration, error),
    data=iteration.error)+
  make_tallrect(iteration.error, "iteration", alpha=0.3)

Nous créons également un graphique qui indique l’itération en cours,

gg.iteration <- ggplot()+
  scale_fill_manual(values=color.code)+
  scale_color_manual(values=c(centers="black", data=NA))+
  scale_size_manual(values=c(centers=5, data=2))+
  geom_point(aes(
    Petal.Length, Petal.Width, fill=cluster, size=type, color=type),
    showSelected="iteration",
    data=data.and.centers)+
  coord_equal()+
  theme_bw()
gg.iteration

La combinaison des deux graphiques permet d’obtenir une visualisation des données interactive,

animint(gg.err, gg.iteration)

17.3 Résumé du chapitre et exercices

Exercices :

  • Faites en sorte que les centres apparaissent toujours au premier plan (au-dessus des données).
  • Ajoutez des transitions en douceur.
  • Ajoutez une animation sur la variable d’itération.
  • Le code actuel impose un nombre maximum d’itérations, il est donc possible que les dernières ne progressent pas. Par exemple, dans l’image ci-dessus, l’itération 16 est la dernière à réduire l’erreur (les itérations 17 à 20 n’entraînent aucune diminution). Modifiez le code pour qu’il arrête l’itération s’il n’y a pas de diminution de l’erreur.
  • La visualisation actuelle n’a qu’un seul cadre d’animation (sous-ensemble showSelected) par itération (la moyenne est affichée avant sa mise à jour). Ajoutez un autre cadre d’animation qui montre la moyenne après la mise à jour.
  • Ajoutez des segments interactifs qui montrent la distance entre chaque point de données et chaque centre d’agrégation (comme dans le premier animint de cette page).
  • Ajoutez les fonctionnalités décrites dans les exercices de la section précédente.
  • Calculez les résultats pour différentes graines aléatoires, puis affichez les taux d’erreur correspondants sur le graphique d’aperçu des erreurs. Permettez ensuite à l’utilisateur de sélectionner n’importe lequel de ces résultats.
  • Calculez les résultats pour plusieurs nombres d’agrégations (K). Calculez l’indice de Rand ajusté en utilisant pdfCluster::adj.rand.index(species, cluster) pour chaque K et graine aléatoire différents. Ajoutez un graphique d’ensemble qui montre la valeur ARI de chaque modèle et autorise la sélection du nombre d’agrégations.
  • Effectuez une visualisation similaire en utilisant un autre ensemble de données tel que data("penguins", package="palmerpenguins").

Ensuite, dans le chapitre 18, nous vous expliquerons comment visualiser l’algorithme d’apprentissage par descente de gradient pour l’apprentissage des réseaux neuronaux.