17  K-moyennes

Dans ce chapitre, nous explorerons plusieurs visualisations de l’agrégation (clustering) par K-moyennes, un algorithme d’apprentissage non supervisé.

Plan du chapitre :

17.1 Visualisation des données Iris avec des étiquettes

Nous commençons par une visualisation typique des données de l’iris, comprenant une légende de couleurs pour les espèces.

library(animint2)
color.code <- c(
  setosa="#1B9E77",
  versicolor="#D95F02",
  virginica="#7570B3",
  "1"="#E7298A",
  "2"="#66A61E",
  "3"="#E6AB02", 
  "4"="#A6761D")
ggplot()+
  scale_color_manual(values=color.code)+
  geom_point(aes(
    Petal.Length, Petal.Width, color=Species),
    data=iris)+
  coord_equal()

Nous allons illustrer l’algorithme d’agrégation K-moyennes à l’aide de deux dimensions.

data.mat <- as.matrix(iris[,c("Petal.Width","Petal.Length")])
head(data.mat)
     Petal.Width Petal.Length
[1,]         0.2          1.4
[2,]         0.2          1.4
[3,]         0.2          1.3
[4,]         0.2          1.5
[5,]         0.2          1.4
[6,]         0.4          1.7
str(data.mat)
 num [1:150, 1:2] 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
 - attr(*, "dimnames")=List of 2
  ..$ : NULL
  ..$ : chr [1:2] "Petal.Width" "Petal.Length"

Pour exécuter K-moyennes, l’hyperparamètre du nombre d’agrégations (K) doit être fixé à l’avance. Ensuite, les K-points de données aléatoires sont sélectionnés comme centres d’agrégation initiaux.

K <- 3
library(data.table)
data.dt <- data.table(data.mat)
set.seed(3)
centers.dt <- data.dt[sample(1:.N, K)]
(centers.mat <- as.matrix(centers.dt))
     Petal.Width Petal.Length
[1,]         0.2          1.4
[2,]         2.1          5.4
[3,]         0.2          1.2
centers.dt[, cluster := factor(1:K)]
centers.dt
   Petal.Width Petal.Length cluster
1:         0.2          1.4       1
2:         2.1          5.4       2
3:         0.2          1.2       3
gg.centers <- ggplot()+
  scale_color_manual(values=color.code)+
  geom_point(aes(
    Petal.Length, Petal.Width),
    color="grey50",
    data=data.dt)+
  geom_point(aes(
    Petal.Length, Petal.Width, color=cluster),
    data=centers.dt)+
  coord_equal()
gg.centers

Ci-dessus, nous avons affiché les deux ensembles de données (centres d’agrégation et données) à l’aide de deux geom_point(). Nous calculons ci-dessous la distance entre chaque point de données et chaque centre d’agrégation :

pairs.dt <- data.table(expand.grid(
  centers.i=1:nrow(centers.mat),
  data.i=1:nrow(data.mat)))

Ces éléments peuvent être visualisés à l’aide d’un geom_point().

seg.dt <- pairs.dt[, data.table(
  data.i,
  data=data.mat[data.i,],
  center=centers.mat[centers.i,])]
gg.centers+
  geom_segment(aes(
    data.Petal.Length, data.Petal.Width,
    xend=center.Petal.Length, yend=center.Petal.Width),
    size=1,
    data=seg.dt)

Il y a 450 segments superposés ci-dessus, si bien que l’interactivité serait utile pour mettre en évidence les segments connectés à un point de données précis. À cet effet, nous créons ci-dessous une variable de sélection data.i :

animint(
  ggplot()+
    theme_bw()+
    theme_animint(height=300, width=640)+
    scale_color_manual(values=color.code)+
    scale_x_continuous(breaks=seq(1,7,by=0.5))+
    scale_y_continuous(breaks=seq(0, 2.5, by=0.5))+
    geom_point(aes(
      Petal.Length, Petal.Width, color=cluster),
      size=4,
      data=centers.dt)+
    geom_segment(aes(
      data.Petal.Length, data.Petal.Width,
      xend=center.Petal.Length, yend=center.Petal.Width),
      size=1,
      showSelected="data.i",
      data=seg.dt)+
    geom_point(aes(
      Petal.Length, Petal.Width),
      clickSelects="data.i",
      size=2,
      color="grey50",
      data=data.table(data.mat, data.i=1:nrow(data.mat))))

Dans la visualisation des données ci-dessus, vous pouvez cliquer sur un point de données pour afficher les distances entre ce point et chaque centre d’agrégation.

Exercices pour cette section :

  • Modifiez les échelles x/y de façon à ce que les mêmes graduations soient représentées.
  • Modifiez la couleur de chaque segment pour qu’elle soit la même que l’agrégation correspondante.
  • Ajoutez une infobulle qui indique la valeur de la distance.
  • Faites varier la largeur du segment en fonction de son optimalité (le segment connecté au centre de l’agrégation la plus proche devrait être mis en évidence par une largeur plus importante).

17.2 Visualisation des itérations de l’algorithme

Nous calculons le centre d’agrégation le plus proche pour chaque point de données.

pairs.dt[, error := rowSums(
(data.mat[data.i,]-centers.mat[centers.i,])^2)]
(closest.dt <- pairs.dt[, .SD[which.min(error)], by=data.i])
     data.i centers.i error
  1:      1         1  0.00
  2:      2         1  0.00
 ---                       
149:    149         2  0.04
150:    150         2  0.18
(closest.data <- closest.dt[, .(
  data.dt[data.i],
  cluster=factor(centers.i)
)])
     Petal.Width Petal.Length cluster
  1:         0.2          1.4       1
  2:         0.2          1.4       1
 ---                                 
149:         2.3          5.4       2
150:         1.8          5.1       2
(both.dt <- rbind(
  data.table(type="centers", centers.dt),
  data.table(type="data", closest.data)))
        type Petal.Width Petal.Length cluster
  1: centers         0.2          1.4       1
  2: centers         2.1          5.4       2
 ---                                         
152:    data         2.3          5.4       2
153:    data         1.8          5.1       2
ggplot()+
  scale_fill_manual(values=color.code)+
  scale_color_manual(values=c(centers="black", data="grey"))+
  scale_size_manual(values=c(centers=5, data=3))+
  geom_point(aes(
    Petal.Length, Petal.Width, fill=cluster, size=type, color=type),
    data=both.dt)+
  coord_equal()+
  theme_bw()

Ensuite, nous mettons à jour les centres d’agrégation.

new.centers <- closest.dt[, data.table(
  t(colMeans(data.dt[data.i]))
), by=.(cluster=centers.i)]
(new.both <- rbind(
  data.table(type="centers", new.centers),
  data.table(type="data", closest.data)))
        type cluster Petal.Width Petal.Length
  1: centers       1       0.300     1.595918
  2: centers       3       0.175     1.125000
 ---                                         
152:    data       2       2.300     5.400000
153:    data       2       1.800     5.100000
ggplot()+
  scale_fill_manual(values=color.code)+
  scale_color_manual(values=c(centers="black", data="grey"))+
  scale_size_manual(values=c(centers=5, data=3))+
  geom_point(aes(
    Petal.Length, Petal.Width, fill=cluster, size=type, color=type),
    data=new.both)+
  coord_equal()+
  theme_bw()

Les visualisations ci-dessus illustrent les étapes de la méthode K-moyennes : (1) mise à jour de l’agrégation sur la base du centre le plus proche (2) mise à jour du centre à partir des données affectées à cette agrégation. Pour visualiser plusieurs itérations des deux étapes ci-dessus, nous utilisons une boucle for.

set.seed(3)
centers.dt <- data.dt[sample(1:.N, K)]
(centers.mat <- as.matrix(centers.dt))
     Petal.Width Petal.Length
[1,]         0.2          1.4
[2,]         2.1          5.4
[3,]         0.2          1.2
data.and.centers.list <- list()
iteration.error.list <- list()
for(iteration in 1:20){
  pairs.dt[, error := {
    rowSums((data.mat[data.i,]-centers.mat[centers.i,])^2)
  }]
  closest.dt <- pairs.dt[, .SD[which.min(error)], by=data.i]
  iteration.error.list[[iteration]] <- data.table(
    iteration, error=sum(closest.dt[["error"]]))
  iteration.both <- rbind(
    data.table(type="centers", centers.dt, cluster=1:K),
    closest.dt[, data.table(
      type="data", data.dt[data.i], cluster=factor(centers.i))])
  data.and.centers.list[[iteration]] <- data.table(
    iteration, iteration.both)
  new.centers <- closest.dt[, data.table(
    t(colMeans(data.dt[data.i]))
  ), keyby=.(cluster=centers.i)]
  centers.dt <- new.centers[, names(centers.dt), with=FALSE]
  centers.mat <- as.matrix(centers.dt)
}
(data.and.centers <- do.call(rbind, data.and.centers.list))
      iteration    type Petal.Width Petal.Length cluster
   1:         1 centers         0.2          1.4       1
   2:         1 centers         2.1          5.4       2
  ---                                                   
3059:        20    data         2.3          5.4       2
3060:        20    data         1.8          5.1       2
(iteration.error <- do.call(rbind, iteration.error.list))
    iteration     error
 1:         1 123.63000
 2:         2  85.82705
---                    
19:        19  31.37136
20:        20  31.37136

Nous créons d’abord un graphique d’ensemble avec une courbe d’erreur pour sélectionner la taille du modèle.

gg.err <- ggplot()+
  theme_bw()+
  geom_point(aes(
    iteration, error),
    data=iteration.error)+
  make_tallrect(iteration.error, "iteration", alpha=0.3)

Nous créons également un graphique montrant l’itération en cours.

gg.iteration <- ggplot()+
  scale_fill_manual(values=color.code)+
  scale_color_manual(values=c(centers="black", data=NA))+
  scale_size_manual(values=c(centers=5, data=2))+
  geom_point(aes(
    Petal.Length, Petal.Width, fill=cluster, size=type, color=type),
    showSelected="iteration",
    data=data.and.centers)+
  coord_equal()+
  theme_bw()
gg.iteration

La combinaison des deux graphiques produit une visualisation interactive.

animint(gg.err, gg.iteration)

17.3 Résumé du chapitre et exercices

Exercices :

  • Faites en sorte que les centres apparaissent toujours au premier plan (au-dessus des données).
  • Ajoutez des transitions fluides.
  • Animez la variable d’itération.
  • Le code actuel impose un nombre maximal d’itérations, il est donc possible que les dernières ne progressent pas. Par exemple, l’itération 16 est la dernière à réduire l’erreur (les itérations 17 à 20 n’entraînent aucune diminution). Modifiez le code pour arrêter l’itération en absence de diminution de l’erreur.
  • La visualisation actuelle n’a qu’un cadre d’animation (sous-ensemble showSelected) par itération (la moyenne est affichée avant sa mise à jour). Ajoutez un autre cadre d’animation qui montre la moyenne après la mise à jour.
  • Ajoutez des segments interactifs qui montrent la distance entre chaque point de données et chaque centre d’agrégation (comme dans le premier animint de cette page).
  • Ajoutez les fonctionnalités décrites dans les exercices de la section précédente.
  • Calculez les résultats pour différentes graines aléatoires, puis affichez les taux d’erreur correspondants sur le graphique d’aperçu des erreurs. Permettez ensuite à l’utilisateur de sélectionner n’importe lequel de ces résultats.
  • Calculez les résultats pour plusieurs nombres d’agrégations (K). Calculez l’indice de Rand ajusté en utilisant pdfCluster::adj.rand.index(species, cluster) pour chaque K et graine aléatoire. Ajoutez un graphique d’ensemble qui montre la valeur ARI de chaque modèle et autorise la sélection du nombre d’agrégations.
  • Créez une visualisation similaire d’un autre ensemble de données tel que data("penguins", package="palmerpenguins").

Dans le chapitre 18, nous vous expliquerons comment visualiser l’algorithme d’apprentissage par descente de gradient pour l’apprentissage des réseaux neuronaux.