Computer-aided diagnosis of retinopathy based on vision transformer
Abstract
Age-related Macular Degeneration (AMD) and Diabetic Macular Edema (DME) are two common retinal diseases among elderly people that may ultimately cause irreversible blindness. Timely and accurate diagnosis is essential for the treatment of these diseases. In recent years, computer-aided diagnosis (CAD) has been extensively investigated and effectively used for rapid and early diagnosis. In this paper, we propose a CAD method that uses a vision transformer to analyze optical coherence tomography (OCT) images and automatically discriminate among AMD, DME, and normal eyes. A classification accuracy of 99.69% was achieved. After model pruning, the recognition time for a single image was reduced to 0.010 s with no drop in classification accuracy. Compared with Convolutional Neural Network (CNN) image classification models (VGG16, Resnet50, Densenet121, and EfficientNet), the pruned vision transformer exhibited better recognition ability. The results show that the vision transformer is an improved alternative for diagnosing retinal diseases more accurately.
1. Introduction
Age-related macular degeneration (AMD) and diabetic macular edema (DME) are two common retinal diseases that occur among the elderly. AMD mostly emerges in people over 45 years old, and its prevalence increases gradually with age.1,2,3 DME, which is the direct cause of vision impairment in most diabetic patients, has a high incidence all over the world.4,5 Timely and accurate diagnosis of AMD and DME is an important prerequisite for effective treatment to prevent irreversible blindness. Noninvasive medical imaging technology that visualizes tissue in three dimensions at the cellular level can improve early medical diagnosis and contribute to the development of novel therapies. Optical coherence tomography (OCT) is a noninvasive, high-resolution optical medical diagnostic imaging technique, which enables in vivo cross-sectional tomographic visualization of the internal microstructure of eyes. In recent years, it has become a gold standard for eye disease diagnosis.6,7,8,9,10,11,12
Diagnosing retinal diseases using OCT images is currently the most effective approach to computer-aided diagnosis (CAD) of retinopathy. In recent years, deep learning has developed rapidly and has provided many solutions in the field of medical diagnosis. In particular, the Convolutional Neural Network (CNN) has achieved great success in medical image classification, and many researchers have used CNNs to classify OCT images for disease diagnosis. Shih et al. classified four types of OCT images, covering three retinal diseases and the normal retina, using the VGG16 model. The accuracy of the model on the test set reached 99%, demonstrating the effectiveness of deep learning in the diagnosis of retinal diseases.13 Zhang Quan et al. proposed a multi-scale deep learning model that uses OCT images to identify diabetic macular edema, and the model achieved an accuracy of 94.5%. By comparison with other models, the authors showed that this model was better able to recognize low-quality medical images.14 Saratxaga Cristina et al. used a deep learning algorithm to classify OCT images of mouse colons; the recognition accuracy of their model reached 96.65%, which demonstrated the strong performance of deep learning in CAD of colonic lesions.15 Luo Yuemei used OCT images combined with deep learning to assist in the diagnosis of sebaceous glands. The classification accuracy of the model proposed in that work reached 97.9%, which is of great help to medical personnel in the diagnosis and treatment of serum problems.16 Potapenko Ivan et al. used noisy OCT image data sets for deep learning training to identify AMD. The recognition accuracy of the model reached 90.9%, which showed that the model could assist medical personnel in the clinical diagnosis of AMD.17
It should be noted that most of the above-mentioned studies are based on CNNs. CNNs are able to extract image features very well, which has been verified by a large number of scholars. However, there is still little research on CAD based on the vision transformer. The vision transformer is a new image classification model that was proposed in 2020.18 It does not rely on any CNN; it is based entirely on the transformer structure, whose feature extraction method differs from that of CNNs. Yakoub Bazi et al. have shown that the vision transformer has better classification capabilities than CNNs on image classification problems. They applied the vision transformer to remote sensing image classification and tested it on multiple remote sensing image data sets. Experimental results showed that the classification accuracy of the vision transformer on remote sensing images exceeded that of CNN-based methods.19
In this paper, we integrated the vision transformer and OCT images to improve the diagnosis of retinal diseases. Focusing on the two common retinal diseases AMD and DME, we collected OCT images of AMD and DME, as well as images of normal ocular fundus. We then trained the vision transformer to classify these three types of OCT images in order to diagnose retinal diseases. In Sec. 2, we introduce the dataset and the vision transformer used in this paper. In Sec. 3, we analyze the experimental results, discuss the impact of model pruning, and compare the results with those obtained using CNNs. The full study is summarized in Sec. 4.
2. Materials and Methods
2.1. Dataset
The dataset used in this paper consists of OCT fundus images of 15 normal subjects, 15 AMD patients, and 15 DME patients collected by Duke University.20 The dataset includes 1407 OCT images of normal retinas, 723 OCT images of AMD, and 1101 OCT images of DME. Examples of these three types of OCT images are shown in Fig. 1.

Fig. 1. Examples of these three types of OCT images. (a) Normal, (b) AMD, and (c) DME.
As shown in Fig. 1(a), normal retinal OCT images show the inner retinal layers, including the nerve fiber layer (NFL), ganglion cell layer (GCL), inner plexiform layer (IPL), and inner nuclear layer (INL), and the outer layers, including the outer plexiform layer (OPL), outer nuclear layer (ONL), myoid and ellipsoid zone (MEZ), the outer segment of photoreceptors (OS), and the retinal pigment epithelium (RPE). OCT images of AMD fundus, shown in Fig. 1(b), present a large number of drusen/choroidal neovascularization (CNV) formations, which result in submacular hemorrhage and leakage. As can be seen from Fig. 1(c), OCT images of DME present with retinal edema, hemorrhage, cystic macular edema, and subretinal fluid.
2.2. Data preprocessing
The images must be preprocessed before they are input to the model for recognition. Image resizing and normalization are the main preprocessing steps. In this study, the images were resized to 224 × 224 and then normalized.21,22,23
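The two preprocessing steps can be illustrated with a minimal, dependency-free sketch. The paper does not state the interpolation method or the normalization constants, so the nearest-neighbour resize, the `mean` and `std` values, and the function names below are assumptions chosen for illustration only.

```python
def resize_nearest(image, out_h=224, out_w=224):
    """Nearest-neighbour resize of a 2-D list of pixel values (assumed method)."""
    in_h, in_w = len(image), len(image[0])
    return [
        [image[r * in_h // out_h][c * in_w // out_w] for c in range(out_w)]
        for r in range(out_h)
    ]


def normalize(image, mean=0.5, std=0.5):
    """Scale 8-bit pixels to [0, 1], then standardize (mean/std are assumptions)."""
    return [[(p / 255.0 - mean) / std for p in row] for row in image]


# Toy grayscale image: 4 rows x 6 columns of 8-bit values (made-up data).
toy = [[(r * 6 + c) * 10 for c in range(6)] for r in range(4)]
prepared = normalize(resize_nearest(toy))
print(len(prepared), len(prepared[0]))  # 224 224
```

In a PyTorch pipeline the same two steps would normally be delegated to library transforms rather than written by hand.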
2.3. Vision transformer
Vision transformer is completely implemented based on the transformer structure, which is widely used in the field of Natural Language Processing (NLP). The transformer structure consists of a set of encoder components and a group of decoder components. Vision transformer is an image classification model and does not require decoder components, so there is only an encoder component in the vision transformer.
The encoder component is composed of a stack of six identical encoders. Each encoder consists of a multi-head attention layer and a feed-forward layer, the latter implemented as a multilayer perceptron (MLP). Both layers use a residual connection and LayerNorm. The structure of an encoder component is shown in Fig. 2.

Fig. 2. The structure of an encoder component.
The multi-head attention in the encoder component structure is a self-attention structure, which allows the model to focus on different aspects of information. Scaled dot-product attention is the attention score calculation method of multi-head attention. The structure of the multi-head attention and the structure of scaled dot-product attention are shown in Fig. 3.

Fig. 3. The structure of the multi-head attention and the structure of scaled dot-product attention.
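As an illustration of scaled dot-product attention, the following sketch computes attention weights and outputs for small matrices stored as nested lists. It follows the standard formulation (dot products scaled by the square root of the key dimension, then a softmax over keys); the function names and the toy inputs are assumptions, and a real model would use tensor operations instead.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(Q, K, V):
    """Return (output, weights) for query, key, and value matrices."""
    d_k = len(K[0])
    scale = math.sqrt(d_k)
    # One row of scores per query: dot product with every key, scaled.
    scores = [
        [sum(qi * ki for qi, ki in zip(q, k)) / scale for k in K]
        for q in Q
    ]
    weights = [softmax(row) for row in scores]
    # Each output row is a weighted sum of the value rows.
    output = [
        [sum(w * v[j] for w, v in zip(w_row, V)) for j in range(len(V[0]))]
        for w_row in weights
    ]
    return output, weights


# Toy example: identical keys give uniform attention weights.
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 1.0], [1.0, 1.0]]
V = [[2.0, 0.0], [0.0, 4.0]]
output, weights = scaled_dot_product_attention(Q, K, V)
print(weights)  # [[0.5, 0.5], [0.5, 0.5]]
print(output)   # [[1.0, 2.0], [1.0, 2.0]]
```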
Multi-head attention is shown in the following equations:
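In the standard transformer formulation, which is assumed here to be the one intended, multi-head attention is written as follows. The symbols $Q$, $K$, $V$ (query, key, and value matrices), $d_k$ (key dimension), $h$ (number of heads), and the projection matrices $W_i^Q$, $W_i^K$, $W_i^V$, $W^O$ follow the usual notation and are not taken from the text above.

```latex
\begin{align}
\mathrm{Attention}(Q, K, V) &= \mathrm{softmax}\!\left(\frac{Q K^{\mathsf{T}}}{\sqrt{d_k}}\right) V \\
\mathrm{head}_i &= \mathrm{Attention}\!\left(Q W_i^{Q},\; K W_i^{K},\; V W_i^{V}\right) \\
\mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}\!\left(\mathrm{head}_1, \ldots, \mathrm{head}_h\right) W^{O}
\end{align}
```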
The linear embedding layer is a crucial structure in the vision transformer. After being input into the vision transformer, an image is divided into multiple patches, and patch embedding is applied to each patch. Embedding is a spatial mapping method commonly used in NLP, which maps high-dimensional vectors to low-dimensional spaces. Patch embedding is the embedding operation applied to a patch, in which each patch is flattened into a one-dimensional tensor. After the patch embedding operation, the positional embedding and class embedding are added, and the result is passed to the transformer encoder. The output of the transformer encoder then goes through an MLP head, which is composed of a fully connected layer and a Gaussian Error Linear Unit (GELU) activation function. The structure of the MLP is shown in Fig. 4.

Fig. 4. The structure of the MLP.
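The patch-splitting step that precedes patch embedding can be sketched as follows. The patch size is not stated in this paper; the value 16 used here is an assumption (it is a common choice for vision transformers), as are the function name and the single-channel toy image.

```python
def split_into_patches(image, patch_size=16):
    """Split a 2-D image into non-overlapping patches, each flattened to 1-D."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be a multiple of patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = [
                image[r][c]
                for r in range(top, top + patch_size)
                for c in range(left, left + patch_size)
            ]
            patches.append(patch)
    return patches


# A 224 x 224 single-channel image (all zeros, placeholder data).
image = [[0.0] * 224 for _ in range(224)]
patches = split_into_patches(image)
print(len(patches), len(patches[0]))  # 196 256
```

Under these assumptions a 224 × 224 input yields 14 × 14 = 196 patches, each of which would then be mapped to an embedding vector by the linear embedding layer.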
The equation of the GELU activation function is
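In its standard form, GELU weights the input by the standard normal cumulative distribution function $\Phi$. The expression below is that standard definition together with its common tanh approximation; which of the two was used in this work is not stated, so both are given as a reference rather than as the exact form applied here.

```latex
\begin{align}
\mathrm{GELU}(x) &= x\,\Phi(x) = \frac{x}{2}\left[1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right] \\
&\approx \frac{x}{2}\left[1 + \tanh\!\left(\sqrt{\frac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right]
\end{align}
```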

Fig. 5. The structure of vision transformer.
2.4. Symmetric cross-entropy loss function
The loss function used in this paper is the symmetric cross-entropy loss function, which can reduce the impact of noise in the data set on training and can prevent overfitting.24 The symmetric cross-entropy loss function is defined as
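The commonly used definition of symmetric cross-entropy combines the usual cross-entropy with a reverse cross-entropy term. The notation below is an assumption: $q(k \mid x)$ is the ground-truth label distribution, $p(k \mid x)$ the predicted distribution over $K$ classes, and $\alpha$, $\beta$ are weighting hyperparameters whose values are not given in this paper.

```latex
\begin{align}
\ell_{\mathrm{SCE}} &= \alpha\,\ell_{\mathrm{CE}} + \beta\,\ell_{\mathrm{RCE}} \\
\ell_{\mathrm{CE}} &= -\sum_{k=1}^{K} q(k \mid x)\,\log p(k \mid x) \\
\ell_{\mathrm{RCE}} &= -\sum_{k=1}^{K} p(k \mid x)\,\log q(k \mid x)
\end{align}
```

Because $q(k \mid x)$ is zero for the non-target classes of a one-hot label, $\log 0$ in the reverse term is usually replaced by a negative constant in practice.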
2.5. Experimental environment
The experimental environment of this paper is as follows: The hardware environment is Intel Core i7-9700f processor, NVIDIA RTX2060s 8GB graphics card and 16 GB memory; the software environment is Win10 system, Python 3.7, and the deep learning framework used is PyTorch.25
2.6. Evaluation standard
In this paper, accuracy was adopted as the standard for evaluating the performance of the model. The definition of accuracy is shown in Eq. (8). In addition, we analyzed the recognition ability for each type of OCT image using the confusion matrix, which compares the predicted results with the real results.
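A minimal sketch of both evaluation tools is given below, taking accuracy as the fraction of correctly classified images (the usual definition, assumed here). The function names and the six toy labels are made up for illustration and are not results from this study.

```python
def confusion_matrix(true_labels, predicted_labels, classes):
    """Rows are true classes, columns are predicted classes."""
    index = {name: i for i, name in enumerate(classes)}
    matrix = [[0] * len(classes) for _ in classes]
    for t, p in zip(true_labels, predicted_labels):
        matrix[index[t]][index[p]] += 1
    return matrix


def accuracy(matrix):
    """Correct predictions (the diagonal) divided by all predictions."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


classes = ["Normal", "AMD", "DME"]
y_true = ["Normal", "Normal", "AMD", "AMD", "DME", "DME"]  # toy labels
y_pred = ["Normal", "AMD", "AMD", "AMD", "DME", "DME"]     # toy predictions
cm = confusion_matrix(y_true, y_pred, classes)
print(cm)            # [[1, 1, 0], [0, 2, 0], [0, 0, 2]]
print(accuracy(cm))  # 0.8333333333333334
```

Off-diagonal entries show which classes are confused with each other, which a single accuracy figure cannot reveal.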
2.7. Experimental procedure and hyperparameter settings
The experimental process and parameter settings of this paper were as follows:
(1) The data set was divided into a training set, a validation set, and a test set in an 8:1:1 ratio. The training set was used to train the model, the validation set was used to check how well the model had fitted, and the test set was used to test the classification ability of the model.
(2) The data were preprocessed before being input to the model.
(3) The Ranger optimizer was used; the learning rate was set to 0.00003, which is a commonly used learning rate for training classification models; the model was trained for 100 epochs, which is sufficient for a model trained on a small data set.
(4) The model parameters that achieved the highest accuracy on the validation set were saved.
(5) The saved model was evaluated on the test set to obtain the test accuracy.
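The 8:1:1 split in step (1) can be sketched as below. Whether the split was random, stratified by class, or grouped by patient is not stated, so the plain shuffled split, the fixed seed, and the function name are assumptions; the total of 3231 images is the sum of the class counts given in Sec. 2.1.

```python
import random


def split_dataset(items, ratios=(8, 1, 1), seed=0):
    """Shuffle items and split them into train/validation/test by ratio."""
    items = list(items)
    random.Random(seed).shuffle(items)  # seeded for reproducibility
    total = sum(ratios)
    n_train = len(items) * ratios[0] // total
    n_val = len(items) * ratios[1] // total
    train = items[:n_train]
    val = items[n_train:n_train + n_val]
    test = items[n_train + n_val:]  # remainder goes to the test set
    return train, val, test


image_ids = range(1407 + 723 + 1101)  # 3231 images in total
train, val, test = split_dataset(image_ids)
print(len(train), len(val), len(test))  # 2584 323 324
```

Because several images come from the same patient, a patient-level split would be the stricter choice if leakage between subsets is a concern.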
3. Results and Discussion
3.1. Experimental results
The experiment was performed according to the procedure described in Sec. 2.7. The error on the training set and the accuracy on the validation set were recorded as the basis for analyzing the fitting ability of the model; both are shown in Fig. 6.

Fig. 6. (a) The error changes of the training set; (b) the accuracy changes of the validation set.
It can be seen that the error on the training set drops very quickly and tends to be stable after the 40th epoch. At the same time, the accuracy on the validation set rises very fast and also tends to be stable after the 40th epoch. These two curves indicate that the vision transformer fits normally without overfitting.
The accuracy of the model on the test set is 99.69% and the confusion matrix of the test set is shown in Fig. 7.

Fig. 7. The confusion matrix of the test set.
From the confusion matrix, we can conclude that the vision transformer classified all OCT fundus images of AMD and DME in the test set correctly, which suggests that missed diagnoses would be unlikely in practical applications. For the OCT images of normal fundus, there is one classification error. In actual diagnosis, a patient has multiple OCT images, so the misclassification of one normal OCT image is unlikely to affect the final diagnosis of retinopathy.
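As a consistency check on these figures, a single misclassified image out of a test set of roughly one tenth of the 3231 images reproduces the reported accuracy. The exact test-set size is not stated, so both plausible sizes under an 8:1:1 split (323 and 324) are assumptions tried here.

```python
def accuracy_with_one_error(test_size):
    """Accuracy when exactly one image in the test set is misclassified."""
    return (test_size - 1) / test_size


for test_size in (323, 324):  # assumed sizes, about 10% of 3231 images
    print(test_size, f"{accuracy_with_one_error(test_size):.2%}")
# 323 99.69%
# 324 99.69%
```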
In order to verify the effectiveness of the proposed method, the model presented in this paper was compared with conventional CNN models. We chose VGG16,26 Resnet50,27 Densenet121,28 and EfficientNet29 as the comparison models. The experimental procedures and hyperparameter settings of these models were consistent with those of the vision transformer.
Model pruning is a method to reduce the memory demand and calculation demand of the model and improve the calculation speed of the model. By encouraging the sparsity of channels in the vision transformer, the important channels are selected and those channels with zero or small coefficients are discarded to achieve an efficient classification model.30 In this paper, the model pruning is performed on the vision transformer to improve the recognition speed of a single OCT image under the condition that the recognition ability of the model does not decrease significantly.
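The channel-selection idea can be illustrated with a toy example: each channel has a learned coefficient, channels whose coefficient is zero or close to zero are dropped, and the next linear layer keeps only the matching weight columns. This is a hypothetical sketch of the general principle, not the pruning procedure used in this work; the threshold, the coefficient values, and all function names are assumptions.

```python
def select_channels(coefficients, threshold=1e-2):
    """Indices of channels whose coefficient magnitude exceeds the threshold."""
    return [i for i, c in enumerate(coefficients) if abs(c) > threshold]


def linear(weights, x):
    """Apply a bias-free linear layer stored as rows of weights."""
    return [sum(w * v for w, v in zip(row, x)) for row in weights]


def prune_layer(weights, coefficients, keep):
    """Fold surviving coefficients into the weights and drop the other columns."""
    return [[row[i] * coefficients[i] for i in keep] for row in weights]


# Toy values: channels 1 and 2 have zero or tiny coefficients.
coefficients = [0.9, 0.0, 0.004, 1.2]
weights = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 0.0, 2.0]]
x = [1.0, 1.0, 1.0, 1.0]

keep = select_channels(coefficients)  # [0, 3]
pruned_weights = prune_layer(weights, coefficients, keep)

full_out = linear(weights, [c * v for c, v in zip(coefficients, x)])
pruned_out = linear(pruned_weights, [x[i] for i in keep])
print(keep, len(pruned_weights[0]))
```

In this toy case the pruned layer uses half as many multiplications, and its output differs from the full layer only by the small contribution of the discarded channels.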
Table 1 compares the test-set accuracy and the single-image recognition time of the vision transformer, the pruned vision transformer, and the four CNN models.
Model | Accuracy | Recognition time per image |
---|---|---|
VGG16 | 98.51% | 0.014 s |
Resnet50 | 97.32% | 0.089 s |
Densenet121 | 97.02% | 0.091 s |
EfficientNet | 34.16% | 0.032 s |
Vision transformer | 99.69% | 0.017 s |
Vision transformer after pruning | 99.69% | 0.010 s |
As we can see from Table 1, among the CNN image classification models, VGG16 has the highest classification accuracy, 98.51%, but it is still inferior to that of the vision transformer. With its attention mechanism, the vision transformer can focus on image regions that are semantically relevant to the classification target, so it can obtain higher accuracy. In terms of single-image recognition speed, both VGG16 and the vision transformer are faster than the other CNN models. The pruned vision transformer is the fastest of all the models, taking only 0.010 s per image, and its recognition accuracy does not decrease, remaining at 99.69%. Considering both recognition accuracy and recognition speed, the pruned vision transformer is superior to the CNN models in recognizing OCT images of the fundus.
In order to further verify the validity of the model proposed in this paper, we compare the results of the model in this paper with the research results of related literature, and the comparison is shown in Table 2.
Study | Method | Diseases | Accuracy |
---|---|---|---|
Ref. 20 | SVM | AMD, DME | 95.56% |
Ref. 31 | Dictionary learning | AMD, DME | 98.38% |
Ref. 32 | CNN | AMD, DME | 91.33% |
Ref. 33 | CNN | AMD, DME | 96.66% |
Ref. 34 | CNN | AMD, DME | 94.20% |
Ref. 35 | CNN | DME | 96% |
This paper | Vision transformer | AMD, DME | 99.69% |
Table 2 shows that Refs. 20 and 31 used machine learning algorithms, and the accuracy of the model using dictionary learning reached 98.38%, which exceeds the results of the other studies using CNNs. Among the studies using CNNs, the recognition accuracy of the model using a multiscale and multipath CNN architecture is higher than that of general CNN models. The classification accuracy of the model proposed in this paper is the highest among the studies compared, which indicates that the vision transformer recognizes OCT fundus images better than these CNN models and traditional machine learning algorithms.
4. Conclusions
AMD and DME are two retinal diseases that seriously harm the eyesight of the elderly. Timely diagnosis and treatment are very important for patients. In this paper, we presented a CAD method that uses the vision transformer to classify OCT fundus images and thereby diagnose the retinal diseases AMD and DME. The data set used in this paper came from the OCT fundus images collected by Duke University. The symmetric cross-entropy loss function and the Ranger optimizer were selected as the loss function and optimizer of the model. The classification accuracy of our model for AMD, DME, and normal OCT fundus images reached 99.69%. We then pruned the model and compared the pruned vision transformer with the CNN image classification models. We showed that the vision transformer had the best recognition ability, and that the pruned vision transformer had the fastest recognition speed without any decrease in recognition accuracy. The model proposed in this paper can better realize the CAD of retinal diseases.
Acknowledgments
This work was supported by the Science and Technology Innovation Project of Shanghai Science and Technology Commission (19441905800), the National Natural Science Foundation of China (62175156, 81827807, 8210041176, 82101177, 61675134), the Project of State Key Laboratory of Ophthalmology, Optometry and Visual Science, Wenzhou Medical University (K181002), and the Key R&D Program Projects in Zhejiang Province (2019C03045).