
Student becomes teacher: training faster deep learning lightweight networks for automated identification of optical coherence tomography B-scans of interest using a student-teacher framework

Open Access

Abstract

This work explores a student-teacher framework that leverages unlabeled images to train lightweight deep learning models with fewer parameters to perform fast automated detection of optical coherence tomography B-scans of interest. Twenty-seven lightweight models (LWMs) from four model families were first trained on expert-labeled B-scans (∼70 K) classified as either “abnormal” or “normal”, establishing a baseline performance for each model. The LWMs were then trained from random initialization in a student-teacher framework that incorporated a large number of unlabeled B-scans (∼500 K), with a pre-trained ResNet50 model serving as the teacher network. The ResNet50 teacher achieved 96.0% validation accuracy, while the validation accuracy of the LWMs ranged from 89.6% to 95.1%. The best-performing LWMs were 2.53 to 4.13 times faster than ResNet50 (0.109 s to 0.178 s vs. 0.452 s). All LWMs benefited from enlarging the training set with unlabeled B-scans in the student-teacher framework, with several models achieving validation accuracy of 96.0% or higher. The three best-performing models achieved sensitivity and specificity comparable to the teacher network on two hold-out test sets. We demonstrate the effectiveness of a student-teacher framework for training fast LWMs for automated detection of B-scans of interest by leveraging unlabeled, routinely available data.

© 2021 Optical Society of America under the terms of the OSA Open Access Publishing Agreement

1. Introduction

Optical coherence tomography (OCT) technology has transformed ophthalmic imaging, enabling physicians to identify pathological features associated with multiple diseases of the retina. OCT provides volumetric three-dimensional data, including en face visualization of the retinal layers, and can detect small structural changes that are not apparent in other imaging modalities [1]. Currently, OCT imaging is used near ubiquitously to diagnose and monitor the progression of many ophthalmic diseases such as age-related macular degeneration [2], diabetic macular edema [3], and glaucoma [4]. However, commercial OCT devices do not routinely provide automated diagnoses, and detection and monitoring of disease via OCT require expert interpretation of the images by an ophthalmologist. Barriers to developing automated diagnostic algorithms for OCT images include the lack of large training datasets, of standardized methods for image acquisition and processing, and of agreed-upon evaluation metrics, as well as limitations in computing power [5].

Despite these challenges, researchers continue to apply artificial intelligence approaches to automated disease detection. Deep learning is particularly effective for this task, as these models can automatically identify features associated with a given diagnosis without requiring expert interpretation. Indeed, deep learning models have already been developed to identify pathologies such as age-related macular degeneration [6], macular edema [7], glaucomatous optic neuropathy [8], and diabetic retinopathy [9]. Recent studies have demonstrated the potential feasibility of automated multiple-disease classification of OCT images using deep learning, as well as the challenges associated with this task [10–12]. Additionally, a deep learning system has been developed to identify OCT B-scans of interest using a binary classification approach in which the abnormal class was composed of multiple retinal pathologies [13,14].

Transfer learning, in which an existing classification model pre-trained on a very large, generic image dataset is fine-tuned for a specific medical imaging task, is the usual approach for developing these models. Recently, however, the utility of transfer learning for medical imaging has been called into question, given the differing image statistics of medical images and the relatively small number of classes compared to typical object recognition tasks [15]. State-of-the-art deep learning models developed for object recognition typically have large numbers of parameters and require significant training time [12,16], and they may be overparameterized for the task of identifying one or more pathologies on retinal OCT images [15]. Many clinical applications demand a model that performs quickly and is compatible with portable devices, as is the case for the ever-growing field of mobile health (mHealth) [17].

Lightweight models (LWMs), those with fewer parameters and operations, can provide potential advantages over large deep learning models and transfer learning approaches [15,18,19]. While many lightweight models have been shown to perform comparably to much larger models [15,19–21], training them may require specialized approaches to achieve accuracy comparable to large, state-of-the-art models. Knowledge distillation enables improved training of small (student) models by distilling knowledge from a larger (teacher) network, an approach also known as a “student-teacher framework” [22]. Various techniques can be employed to achieve knowledge distillation. One method achieves distillation through semi-supervised learning, in which a larger network generates labels for unlabeled data that are then used to train a smaller network [22,23]. Another method is soft-target training, in which the teacher generates soft target labels for the training data and these soft targets are used in combination with one-hot encoded labels (hard targets) [24]. Combining the concepts of semi-supervised learning and soft-target learning, unlabeled images can be used to distill knowledge from the teacher network to the student networks through soft-target learning. In the case of multiclass classification, multiple specialist networks can be ensembled and their knowledge distilled to a student network [24].

In this study, we explore a student-teacher framework to perform semi-supervised learning with the goal of training LWMs to perform fast automated detection of OCT B-scans of interest. We take advantage of widely available unlabeled images, with the goal of developing a model that can provide fast and accurate results, suitable for a clinical setting.

2. Methods

2.1 Optical coherence tomography dataset

A total of 598 OCT macular volume cube scans (76,544 B-scans) were retrospectively obtained from nine clinical sites in the United States, Germany, Portugal, and Singapore. All macular volume cube scans (512 × 128) were acquired using the CIRRUS 4000 and 5000 (ZEISS, Dublin, CA). Five hundred ninety-eight unique patients with a variety of retinal conditions were included. Two retina specialists first labeled each B-scan for image quality as “gradable” or “ungradable”, resulting in the exclusion of 148 B-scans. The specialists then annotated the remaining 76,396 B-scan images for 8 different pathologies: intraretinal fluid, subretinal fluid, disruption of inner retinal layers, disruption of the vitreoretinal interface (VRI), retinal pigment epithelium (RPE) atrophy, RPE elevation, inner segment (IS)/outer segment (OS) disruption, or other retinal pathologies. A B-scan was labeled abnormal if at least one specialist annotated at least one of the 8 pathological categories for the image. The images were then split at the subject level into training (80%) and validation (20%) data sets. This resulted in a training set of 36,223 normal OCT images and 24,835 images of interest from 478 patients, and a validation set of 9,475 normal images and 5,863 images of interest from 120 patients. In addition to the labeled images, a set of 478,588 unlabeled images from 3,148 patients collected using the CIRRUS 4000 and 5000 devices from multiple sites was available. See Table 1 for a summary of the number of patients and images.
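
Splitting at the subject level keeps all B-scans from a given patient in the same partition. As a minimal sketch (assuming a hypothetical per-B-scan manifest with `patient_id`, `filepath`, and `label` columns, not the authors' actual pipeline), such a grouped split can be done with scikit-learn:

```python
# Sketch of a subject-level 80/20 split using a hypothetical per-B-scan manifest.
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

df = pd.read_csv("bscan_manifest.csv")  # hypothetical manifest of labeled B-scans

splitter = GroupShuffleSplit(n_splits=1, test_size=0.20, random_state=42)
train_idx, val_idx = next(splitter.split(df, groups=df["patient_id"]))
train_df, val_df = df.iloc[train_idx], df.iloc[val_idx]

# No patient contributes B-scans to both partitions.
assert set(train_df["patient_id"]).isdisjoint(set(val_df["patient_id"]))
```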

Table 1. Study data used for training and validation.

Two hold-out test sets were set aside for the final performance testing of the top-performing lightweight models; both sets were collected from clinical sites that did not contribute training or validation data. Hold-out test set-1 contained 200 OCT macular volume cubes (25,600 B-scan images) retrospectively obtained from three clinical sites in the United States and Austria; all cubes were acquired using the CIRRUS 4000 and 5000 (ZEISS, Dublin, CA). Two retina specialists labeled each B-scan for the same 8 retinal pathologies, and a B-scan was labeled abnormal if at least one specialist annotated at least one of the pathologies in that image. The annotations showed excellent agreement between the two specialists (Cohen’s kappa = 0.8922). Hold-out test set-2 contained 225 OCT macular volume cubes (28,800 B-scan images) retrospectively obtained from six clinical sites in the United States and Austria; all cubes were acquired using the CIRRUS 6000 (ZEISS, Dublin, CA). Three optometrists labeled each B-scan for the same 8 retinal pathologies, and a B-scan was labeled abnormal if at least two optometrists annotated at least one of the pathologies in that image. The annotations showed moderate to substantial agreement between the three optometrists (Cohen’s kappa: optometrist 1 & 2 = 0.4859, optometrist 2 & 3 = 0.7992, and optometrist 1 & 3 = 0.5406).
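
Inter-grader agreement was summarized with Cohen's kappa; as a brief illustration with hypothetical label vectors (not the study data), the pairwise statistic can be computed with scikit-learn:

```python
# Pairwise Cohen's kappa between graders' binary B-scan labels (illustrative data).
from itertools import combinations
from sklearn.metrics import cohen_kappa_score

# Hypothetical per-B-scan labels: 1 = abnormal, 0 = normal.
grades = {
    "grader_1": [1, 0, 1, 1, 0, 0, 1, 0],
    "grader_2": [1, 0, 1, 0, 0, 0, 1, 0],
    "grader_3": [1, 0, 0, 1, 0, 1, 1, 0],
}

for (name_a, a), (name_b, b) in combinations(grades.items(), 2):
    print(f"kappa({name_a}, {name_b}) = {cohen_kappa_score(a, b):.4f}")
```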

2.2 Teacher network training

The deep learning architecture for the automatic detection of B-scans of interest was developed using a ResNet-50 network. A 3-channel ResNet-50 [25] architecture was modified by adding inverted drop-out followed by softmax activation. The training images were resized to 224 × 224 and augmented using rotation, horizontal flip, and vertical shift to form the set of B-scans for training. The modified 3-channel ResNet-50 was pre-trained on ImageNet images [26] and then transfer trained on the resized B-scans with all layers unfrozen. Binary cross entropy was used as the loss function, and stochastic gradient descent with momentum 0.9 was used as the optimizer, with initial learning rates ranging from 1e-2 to 1e-6. The model was trained with a batch size of 64 for 31 epochs.
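
For illustration, a minimal Keras sketch of this teacher configuration is shown below; the dropout rate, pooling choice, and head layout are assumptions, since the text does not specify them, and the data pipeline is omitted.

```python
# Sketch of the ResNet-50 teacher: ImageNet-pretrained backbone with dropout
# and a two-class softmax head, fine-tuned end to end (all layers unfrozen).
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50

backbone = ResNet50(weights="imagenet", include_top=False,
                    input_shape=(224, 224, 3), pooling="avg")
backbone.trainable = True  # transfer training with all layers unfrozen

x = layers.Dropout(0.5)(backbone.output)            # dropout rate is illustrative
outputs = layers.Dense(2, activation="softmax")(x)  # normal vs. abnormal
teacher = models.Model(backbone.input, outputs)

teacher.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=1e-2, momentum=0.9),
    loss="binary_crossentropy",
    metrics=["accuracy"],
)
# teacher.fit(train_ds, validation_data=val_ds, epochs=31)  # batch size 64 set in the dataset
```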

2.3 Student network training

We selected 4 general families of lightweight deep learning architectures for our student networks: SqueezeNet [27], SqueezeNext [28], MobileNet [29–31], and ShuffleNet [32,33]. SqueezeNet was shown to obtain AlexNet-level accuracy on ImageNet with 1/50 the parameters, achieved by replacing 3 × 3 convolutions with 1 × 1 convolutions, decreasing the number of channels input to the 3 × 3 convolutions in “squeeze layers”, and downsampling late in the network so that convolution layers have large activation maps. SqueezeNet can be implemented with residual connections; this modification yields SqueezeResNet. SqueezeNext achieved AlexNet’s accuracy with only 0.5 million model parameters (approximately 1/112 the parameters of AlexNet). SqueezeNext uses the SqueezeNet architecture as a baseline with the following changes: significantly reducing the total number of parameters used by the 3 × 3 convolutions, using separable 3 × 3 convolutions to further reduce the model size and removing the additional 1 × 1 branch after the squeeze module, and using an element-wise addition skip connection similar to that of the ResNet architecture. The MobileNet model was shown to perform comparably to larger models, GoogleNet and VGG16, and to outperform SqueezeNet in ImageNet classification. It uses depth-wise separable convolutions, factorizing a standard convolution into a depthwise convolution and a 1 × 1 “pointwise” convolution. The original MobileNet model can include faster down-sampling, yielding the Fast MobileNet model. ShuffleNet uses two new operations, pointwise group convolution and channel shuffle, to greatly reduce computation cost while maintaining accuracy on ImageNet. ShuffleNet was shown to be approximately 13 times faster than AlexNet while maintaining similar accuracy, and it performs comparably to MobileNet with a moderate speedup. A total of 27 LWMs from these four general architecture families were selected as the student networks. Details of the individual student models, including hyperparameters, can be found in Supplementary Table 1.
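
The study's implementations were adapted from the imgclsmob repository (see Section 2.7). Purely as an illustration of the parameter-saving ideas described above, the sketch below shows a SqueezeNet-style fire module and a MobileNet-style depthwise separable block in Keras; the filter counts are arbitrary and are not the values used in the study.

```python
# Illustrative Keras building blocks for two of the lightweight families.
from tensorflow.keras import layers

def fire_module(x, squeeze_filters=16, expand_filters=64):
    """SqueezeNet-style fire module: 1x1 squeeze, then parallel 1x1/3x3 expand."""
    s = layers.Conv2D(squeeze_filters, 1, activation="relu")(x)
    e1 = layers.Conv2D(expand_filters, 1, activation="relu", padding="same")(s)
    e3 = layers.Conv2D(expand_filters, 3, activation="relu", padding="same")(s)
    return layers.Concatenate()([e1, e3])

def depthwise_separable_block(x, pointwise_filters=64, stride=1):
    """MobileNet-style block: depthwise 3x3 conv followed by a 1x1 pointwise conv."""
    x = layers.DepthwiseConv2D(3, strides=stride, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(pointwise_filters, 1, padding="same")(x)
    x = layers.BatchNormalization()(x)
    return layers.ReLU()(x)
```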

For all models, binary cross entropy was used as the loss function, Adam [34] was used as the optimizer, the batch size was set to 64, and models were trained for 20 epochs. A grid search was performed over the learning rate, with initial learning rates ranging from 0.1 to 1e-6. The maximum validation accuracy over epochs and learning rates was calculated for each model. The average inference time across all full 128-B-scan volumes (n=208) was computed on an NVIDIA Tesla P100 graphics processing unit.
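
A sketch of this learning-rate grid search is given below; `build_student` is a hypothetical factory that returns one of the uncompiled lightweight architectures, and `train_ds`/`val_ds` are assumed to be batched (batch size 64) datasets.

```python
# Learning-rate grid search for a student model, tracking the maximum
# validation accuracy over epochs (as reported in the paper).
import numpy as np
import tensorflow as tf

learning_rates = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6]
best = {"lr": None, "val_acc": 0.0}

for lr in learning_rates:
    model = build_student()  # hypothetical factory for a lightweight architecture
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                  loss="binary_crossentropy", metrics=["accuracy"])
    history = model.fit(train_ds, validation_data=val_ds, epochs=20, verbose=0)
    val_acc = float(np.max(history.history["val_accuracy"]))
    if val_acc > best["val_acc"]:
        best = {"lr": lr, "val_acc": val_acc}

print(best)
```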

2.4 Student-teacher framework - semi-supervised learning

The pretrained ResNet50 model (described in Section 2.2) was used as the teacher network for the student (lightweight) networks. Figure 1 describes the student-teacher framework. Inference with the teacher network was performed on the 478,588 unlabeled images, and each image was assigned the class with the highest softmax output from the ResNet50 model, generating hard targets for the unlabeled images. These images and their inferred labels were then pooled with the expert-labeled images, and this pooled dataset was used to train the student models in the student-teacher framework. Additionally, we explored the effect of the number of unlabeled images available for student-teacher training on the validation accuracy by subsampling the unlabeled image set at 25%, 50%, 75%, and 100%, where 0% corresponds to the baseline performance with only labeled images. The same training hyperparameters and grid search for the optimal learning rate were used as in Section 2.3. The maximum validation accuracy was calculated over epochs and learning rates.
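
A minimal sketch of the hard-target pseudo-labeling step follows, assuming the trained `teacher` model, in-memory arrays of preprocessed images, and one-hot expert labels; in practice, a streaming data pipeline would be needed for roughly half a million images.

```python
# Hard pseudo-labels from the teacher: the argmax of its softmax output.
import numpy as np
import tensorflow as tf

probs = teacher.predict(x_unlabeled, batch_size=64)   # shape: (n_unlabeled, 2)
pseudo_labels = np.argmax(probs, axis=1)              # 0 = normal, 1 = abnormal
pseudo_onehot = tf.keras.utils.to_categorical(pseudo_labels, num_classes=2)

# Pool expert-labeled and teacher-labeled images into one training set.
x_pooled = np.concatenate([x_labeled, x_unlabeled], axis=0)
y_pooled = np.concatenate([y_labeled_onehot, pseudo_onehot], axis=0)

# student.fit(x_pooled, y_pooled, batch_size=64, epochs=20,
#             validation_data=(x_val, y_val))
```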

Fig. 1. A. Study flow diagram of the study design. The lightweight models were first trained on only the labeled images used to train the teacher network (ResNet50). The lightweight models were then trained with labeled and unlabeled images in the student-teacher framework and the three best performing models were selected. After model selection, six models (the three best-performing models with and without student-teacher training) were evaluated on the two hold-out test sets. B. Flow diagram of the student-teacher framework. Labels for the unlabeled images are inferred (y_pseudo_unlabeled) by the teacher network (top network). The unlabeled images are combined with the expert-labeled images to train the lightweight student networks (bottom network). The inferred labels for the unlabeled images (y_pseudo_unlabeled) and the human-graded labels (y_true_labeled) are used in a binary cross entropy loss to train the lightweight networks (bottom network). The yellow arrows denote training with labeled images and the purple arrows denote training with unlabeled images.

2.5 Alternative student-teacher framework - soft-target training

In this framework, the pretrained ResNet50 model was again used as the teacher network for the student (lightweight) networks. Rather than learning from hard targets alone, a combined loss function over soft and hard targets was used [24]. The soft targets were generated by performing inference on all the images (labeled or unlabeled) with a temperature-raised softmax [24]. The temperature (T) and the weighting factor alpha are hyperparameters that must be tuned. We selected the model that benefited most from the student-teacher learning described in Section 2.4 and performed soft-target training for a number of temperature and alpha values. After selecting the temperature and alpha that yielded the highest validation accuracy, we performed soft-target learning with the labeled images and 100% of the unlabeled images. The maximum validation accuracy over epochs was reported. We refer to this framework as the soft-target student-teacher framework.
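
A sketch of the combined soft/hard-target loss in the spirit of Hinton et al. [24] is shown below; the exact weighting used in the study is not spelled out, so scaling the soft term by alpha and T² while leaving the hard term unweighted is one plausible reading, not the authors' verified formulation.

```python
# Soft-target distillation loss: temperature-raised softmax targets from the
# teacher combined with the usual hard-target cross entropy.
import tensorflow as tf

T = 3.0       # temperature (tuned hyperparameter)
alpha = 1.5   # weight on the soft-target term (tuned hyperparameter)

def distillation_loss(y_hard, teacher_logits, student_logits):
    soft_targets = tf.nn.softmax(teacher_logits / T)
    soft_preds = tf.nn.softmax(student_logits / T)
    soft_loss = tf.keras.losses.categorical_crossentropy(soft_targets, soft_preds)
    hard_loss = tf.keras.losses.categorical_crossentropy(
        y_hard, tf.nn.softmax(student_logits))
    # The soft term is scaled by T^2 so its gradients stay comparable to the
    # hard term as the temperature changes (Hinton et al.).
    return alpha * (T ** 2) * soft_loss + hard_loss
```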

2.6 Evaluation on the hold-out test sets

The three top-performing models were selected based on the maximum validation accuracy achieved with the student-teacher framework. These three models, with and without student-teacher training (n=6), were evaluated on the two hold-out test sets. Sensitivity and specificity were calculated on the validation set and both hold-out test sets, and 95% confidence intervals were calculated with clustered bootstraps.
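
A clustered bootstrap resamples whole patients rather than individual B-scans, preserving the within-patient correlation of B-scans. A minimal sketch for the sensitivity confidence interval is given below, with hypothetical arrays `y_true`, `y_pred`, and `patient_ids`; specificity follows analogously.

```python
# Clustered (patient-level) bootstrap for a 95% CI on sensitivity.
import numpy as np

def clustered_bootstrap_sensitivity(y_true, y_pred, patient_ids,
                                    n_boot=1000, seed=0):
    rng = np.random.default_rng(seed)
    patients = np.unique(patient_ids)
    stats = []
    for _ in range(n_boot):
        sampled = rng.choice(patients, size=len(patients), replace=True)
        idx = np.concatenate([np.flatnonzero(patient_ids == p) for p in sampled])
        yt, yp = y_true[idx], y_pred[idx]
        tp = np.sum((yt == 1) & (yp == 1))
        fn = np.sum((yt == 1) & (yp == 0))
        stats.append(tp / (tp + fn))
    return np.percentile(stats, [2.5, 97.5])
```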

2.7 Technical details

All analyses were performed using Python (v). Deep learning models were developed using Keras (v) and TensorFlow (v), accelerated using NVIDIA CUDA (9.0.333), and trained on a server with dual Xeon 3.4 GHz processors, 256 GB of random access memory, and 8 × NVIDIA P100 GPUs. All LWMs were implemented in Keras, adapted from code available at https://github.com/osmr/imgclsmob/tree/master/keras.

3. Results

3.1 Student networks

The inference time versus baseline accuracy tradeoff for the teacher and the student models is presented in Fig. 2. The teacher network (ResNet50) achieves 96.0% validation accuracy with an average inference time of 0.452 seconds. The student models range from 89.6% to 95.1% validation accuracy, with average inference times ranging from 0.066 to 0.261 seconds. The best-performing lightweight model (optimizing for both inference time and validation accuracy) is one of the SqueezeNet models with residual connections (SRN.1). SRN.1 achieves 95.1% validation accuracy and is 4.13 times faster than ResNet50 with only 1/32 of the parameters.

Fig. 2. Time vs. performance tradeoff for the lightweight models and the ResNet50 teacher model. Inference time in seconds is plotted against the maximum validation accuracy across epochs and five repeated training runs. The model family is depicted by the color of the marker, with ResNet50 shown in gray. The size of the marker is scaled by the number of parameters in each model. The lightweight models are abbreviated as SqueezeNet (SN), SqueezeResNet (SRN), MobileNet (M), Fast MobileNet (FM), SqueezeNext (SQN), and ShuffleNet (SFN); see Supplementary Table 1 for more details on the abbreviations used.

3.2 Student-teacher framework

All student models benefit from training with the additional data provided by the unlabeled images and their inferred labels generated by the teacher network. Figure 3 shows the gradual increase in validation accuracy observed when training with the labeled images and 0%, 25%, 50%, 75%, and 100% of the unlabeled images. With the student-teacher framework, the validation accuracies across all student models are boosted to between 94.8% and 96.3%. SRN.1, SN.1, SQN.1.3, and SQN.2.3 achieve a validation accuracy of 96.1%, slightly exceeding the validation accuracy of the teacher network. One of the MobileNet models, M.2.1, also narrowly beats the teacher network (ResNet50) with a validation accuracy of 96.3%. SRN.1, SN.1, and M.2.1 were selected as the best-performing models as they balance low inference time with high validation accuracy. Figure 4 provides a different view of the results in Fig. 3 by plotting the percentage of unlabeled images used against validation accuracy. This perspective highlights the increase in validation accuracy as more unlabeled images are used. In general, the models see a more pronounced increase at 25% of the unlabeled images and a more gradual increase thereafter. The 95% confidence intervals demonstrate that the lightweight models benefit from the addition of unlabeled images at all percentages.

Fig. 3. Time vs. performance tradeoff for the lightweight models using the student-teacher framework. Inference time in seconds is plotted against the maximum validation accuracy across epochs and five repeated training runs. The model family is depicted by the color of the marker, with ResNet50 shown in gray. The opacity of the markers indicates the percentage of the unlabeled data used. The dotted line marks the accuracy of the ResNet50 teacher model. The lightweight models are abbreviated as SqueezeNet (SN), SqueezeResNet (SRN), MobileNet (M), Fast MobileNet (FM), SqueezeNext (SQN), and ShuffleNet (SFN); see Supplementary Table 1 for more details on the abbreviations used.

Fig. 4. Validation accuracy curves for the lightweight models using the student-teacher framework. Best validation accuracy across epochs and runs is plotted against the percentage of the unlabeled images used; error bars depict the 95% confidence interval. The panels display the results of the various lightweight models trained in each model family.

The soft-target student-teacher framework did not yield an increase in validation accuracy similar to that of the semi-supervised student-teacher framework (Supplementary Table 2) for the MobileNet model (M.3.4) tested. This model was selected because it had the largest gain in validation accuracy with student-teacher training. The baseline validation accuracy using only the labeled data for training is 89.6%; with soft-target training alone the validation accuracy rises to 90.3% (T=3, alpha=1.5), far less than the increase observed with the semi-supervised student-teacher framework (to 95.1%). Likewise, the soft-target student-teacher framework (T=3, alpha=1.5) with labeled and unlabeled data achieves a lower validation accuracy of 94.4%.

3.3 Performance on hold-out test sets

The ResNet50 teacher model and the three best-performing lightweight models (SRN.1, SN.1, and M.2.1) trained with and without the student-teacher framework were evaluated on two hold-out test sets (Table 2). The sensitivity/specificity on the validation set and both test sets for the ResNet50, SRN.1, SRN.1 trained with the student-teacher framework (SRN.1+ST), SN.1, SN.1 trained with the student-teacher framework (SN.1+ST), M.2.1, and M.2.1 trained with the student-teacher framework (M.2.1+ST) are presented in Table 2. The ResNet50 and all lightweight models trained with the student teacher framework achieve comparable sensitivity (92.55 to 94.68) and specificity (93.67 to 96.38) values on the hold-out test set from the CIRRUS 4000/5000. The lightweight models perform slightly worse in terms of sensitivity (86.92 to 89.96) without the student-teacher framework, while the specificity was comparable (95.60 to 96.43). On the CIRRUS 6000 hold-out test set, ResNet50 and SRN.1+ST models again achieve similar sensitivity (83.90 versus 84.38) and specificity (90.75 versus 90.81) levels, while SN.1+ST and M.2.1+ST have lower sensitivity (77.63 and 81.38, respectively) and higher specificity (93.05 and 93.23, respectively). The lightweight models trained without the student-teacher framework achieve lower sensitivities (71.71 to 78.71), but have comparable specificities (92.00 to 92.70) as compared to the ResNet50 and the lightweight models trained with the student-teacher framework.

Table 2. Sensitivity and specificity on validation and hold-out test set with 95% confidence intervals (95% CI), best performing results are bolded.

4. Discussion

In this work, we have demonstrated that lightweight deep learning networks designed for use on mobile devices can be applied to a medical imaging classification task. In particular, we have shown that four general families of LWMs, SqueezeNet, SqueezeNext, ShuffleNet, and MobileNet, can be used to perform B-scan of interest identification with retinal OCT images. Using a student-teacher framework, we demonstrate that all lightweight student networks have improved validation accuracy with the addition of unlabeled data labeled by the teacher. Most notably, several models are able to perform comparably to the state-of-the-art, large teacher model (ResNet50) and run in a fraction of the time, making the LWMs ideal candidates for clinical settings where runtime is important and computing resources are often limited. The student-teacher framework is an appealing paradigm for semi-supervised learning in medical image classification, as the generation of expertly-labeled image datasets is resource intensive.

All of the lightweight models benefited from the student-teacher training, with some models gaining as much as 4-5% in validation accuracy. The best-performing lightweight models, two SqueezeNet models (SRN.1 and SN.1) and one MobileNet model (M.2.1), and the ResNet50 teacher network were evaluated on two hold-out test sets. The results from both hold-out test sets demonstrate that the lightweight models, when trained with the student-teacher framework, perform comparably to the ResNet50 model. In both hold-out test sets, the lightweight models trained without the student-teacher framework perform similarly in terms of specificity to the other models but have lower sensitivities, demonstrating the benefit of harnessing the unlabeled images for training the lightweight models. The sensitivity and specificity were lower for all models on the CIRRUS 6000 hold-out test set as compared to the CIRRUS 4000/5000 hold-out test set. This is not surprising, as the models were trained with data from CIRRUS 4000/5000 instruments, but it does demonstrate that deep learning models can have difficulty generalizing across imaging platforms. Additionally, the Cohen’s kappa between raters was lower for the CIRRUS 6000 hold-out test set, suggesting that its ground-truth labels were noisier than those of the CIRRUS 4000/5000 hold-out test set. In future work, we plan to investigate increasing generalization by utilizing unlabeled B-scans from the CIRRUS 6000 in our student-teacher framework.

We found the soft-target student-teacher framework was less effective than the semi-supervised approach for this binary classification problem. This is likely due to the fact that with only two classes there is little information added by encouraging the student model to mimic the teacher’s behavior on both the soft and hard targets. It could be argued there are no “off-target” labels in the binary case, as binary cross entropy yields a “probability” p for one class and (1-p) for the other class. The student-teacher framework with unlabeled images presents a way to utilize the many unlabeled images available in medical imaging. We did not explore the size of the labeled image dataset, as we treated the teacher network as a fixed model. In future studies, it would be illuminating to explore the minimal requirements for the labeled imaging set size.

There are relatively few examples of LWMs being applied to the classification of medical images, but in general they have been shown to perform comparably to or better than larger, state-of-the-art models. Previously, a lightweight deep learning model was trained to detect the 12 views of transthoracic echocardiography using a student-teacher framework with soft-target training [20]. The LWM had comparable performance to the best available deep learning architectures, but with 1% of the number of parameters and a 6× faster inference time. In addition, a lightweight model has been shown to be effective in detecting fundus image quality and was fast enough to run on mobile devices [18]. Two LWMs were shown to perform comparably to both a ResNet50 and an Inceptionv3 network on disease classification using chest x-ray and fundus photo images [15]. The authors suggest that many state-of-the-art algorithms designed for object recognition may be overparameterized for medical imaging classification tasks. A LWM (with only 6.9% of the parameters of the ResNet-50 model) outperformed the ResNet-50 model on a multiclass classification task on OCT images [35], and a LWM trained to predict brain age from magnetic resonance images outperformed state-of-the-art, heavyweight models [19]. Most recently, lightweight models have been able to detect COVID-19 infections from chest X-ray images with high accuracy [21,36,37].

A popular approach to increasing the performance of deep learning classifiers is to use transfer learning [12,38], in which a large multiclass image classifier is pretrained on a natural image dataset and then retrained on the smaller medical imaging dataset of interest. The resulting model has many fewer parameters that need to be trained than the original model and is “fine-tuned” to the medical imaging task. There may be drawbacks to this approach, as large image classifier models such as ResNet are designed for many more classes than are typical for medical imaging tasks, and medical imaging pathologies often contain unique features that differ from standard natural images [15]. Additionally, most, if not all, natural image datasets are RGB images, while many medical images are grayscale, thus requiring the medical images to be duplicated across channels. The student-teacher approach avoids these drawbacks. The large teacher model, although similar to those used in transfer learning, is not retrained. Instead, the teacher model passes its classification knowledge to the lightweight student model by providing “examples” in the form of labeled training data. The student model is trained only on images that are specific to the task. Additionally, the expansion of the training examples available to the student networks should prevent overfitting and result in more generalized models.
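
As a concrete illustration of the channel duplication mentioned above, a grayscale B-scan can be replicated across three channels before being fed to an RGB-pretrained backbone; a brief sketch (placeholder data, not the study's preprocessing code):

```python
# Replicating a single-channel B-scan across three channels so it matches the
# input shape expected by an RGB-pretrained backbone.
import numpy as np

bscan = np.random.rand(224, 224).astype("float32")          # placeholder grayscale image
bscan_rgb = np.repeat(bscan[..., np.newaxis], 3, axis=-1)   # shape (224, 224, 3)
```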

There have been recent publications exploring knowledge distillation for classification, semantic segmentation, and instance segmentation with medical images. Knowledge distillation was demonstrated to outperform transfer learning in a diabetic retinopathy classification task using color fundus photography [39]. Multiple teacher networks were used to train a single student network to perform multi-task classification with color fundus photographs [40]. A multiple-instance learning (MIL) model combining pseudo-labeling via a mean teacher was effective at detecting intraretinal fluid, subretinal fluid, and pigment epithelial detachments while cutting the requirement for expensive labels by 94.22% [41]. Likewise, using limited labeled data, retinal layers were segmented from OCT images using a student-teacher framework with performance comparable to a model trained in a fully-supervised fashion [42]. A chain of student networks, each distilling knowledge to the next, was shown to need only 0.5% of the available labeled data to accurately detect colorectal cancer from histology slides [43]. Additionally, enforcing consistency across perturbations with unlabeled data has been shown to improve multi-class classification results with chest X-rays [44] and instance segmentation results from microscopy and computed tomography images [45].

As discussed above, one limitation of our study is that we did not investigate reducing the number of labeled images, because the labeled images were used both to train the teacher ResNet50 model and to train the LWMs. A follow-up study with two sets of labeled images, one used for teacher training and one used for student training, would be an interesting paradigm in which to investigate the number of labeled images needed to achieve comparable accuracy. It is difficult to compare the models, as each has its own architecture and was only tested on one binary classification task; more research is warranted before making a more generalized statement about model superiority. A drawback to knowledge distillation is the potential for the student network to learn systematic errors from the teacher model, so-called ‘confirmation bias’. A recent learning framework was proposed to reduce confirmation bias by doing away with a teacher model and instead distilling knowledge between student networks [46]. Future directions of this work are to explore ensembling multiple teachers, deep supervision, and teacher-free knowledge distillation.

This work represents a different perspective on training lightweight deep learning models for medical imaging classification tasks. Expertly-labeled images are expensive to obtain; we exploited a large number of unlabeled images using a student-teacher framework to perform semi-supervised learning. In the end, the fast LWMs achieved performance comparable to the relatively slow, heavyweight ResNet50 model. The framework presented here could expand the horizons for running medical imaging classification algorithms on imaging instruments or mobile devices.

Funding

Research to Prevent Blindness; Carl Zeiss Meditec, Inc.; National Institute on Aging (R01AG060942); National Eye Institute (K23EY029246).

Disclosures

Niranchana Manivannan: Carl Zeiss Meditec, Inc., Dublin, CA (E); Gary C. Lee: Carl Zeiss Meditec, Inc., Dublin, CA (E); Sophia Yu: Carl Zeiss Meditec, Inc., Dublin, CA (E); Mary K. Durbin: Carl Zeiss Meditec, Inc., Dublin, CA (E); Aditya Nair: Carl Zeiss Meditec, Inc., Dublin, CA (E); Rishi P. Singh: Regeneron(C), Genentech(C), Alcon(C), Novartis(C), Bausch and Lomb(C), Gyroscope (C), Apellis(F), Graybug(F), Aerie (F); Katherine Talcott: Zeiss (F), Roche/Genentech (R); Aaron Lee: Carl Zeiss Meditec(F), Novartis(F), Genentech(C), Verana Health(C), US FDA(E), Santen(F), Microsoft(F), NVIDIA(F), Topcon(R)

Data availability

Data underlying the results presented in this paper are not publicly available at this time due to privacy reasons.

Supplemental document

See Supplement 1 for supporting content.

References

1. M. Adhi and J. S. Duker, “Optical coherence tomography – current and future applications,” Current Opinion in Ophthalmology 24(3), 213–221 (2013). [CrossRef]  

2. D. Ferrara, R. E. Silver, R. N. Louzada, E. A. Novais, G. K. Collins, and J. M. Seddon, “Optical coherence tomography features preceding the onset of advanced age-related macular degeneration,” Invest. Ophthalmol. Vis. Sci. 58(9), 3519 (2017). [CrossRef]  

3. G. Trichonas and P. K. Kaiser, “Optical coherence tomography imaging of macular oedema,” Br. J. Ophthalmol. 98(Suppl 2), ii24–ii29 (2014). [CrossRef]  

4. I. I. Bussel, G. Wollstein, and J. S. Schuman, “OCT for glaucoma diagnosis, screening and detection of glaucoma progression,” Br. J. Ophthalmol. 98(Suppl 2), ii15–9 (2014). [CrossRef]  

5. R. T. Yanagihara, C. S. Lee, D. S. W. Ting, and A. Y. Lee, “Methodological challenges of deep learning in optical coherence tomography for retinal diseases: a review,” Trans. Vis. Sci. Tech. 9(2), 11 (2020). [CrossRef]  

6. C. S. Lee, D. M. Baughman, and A. Y. Lee, “Deep learning is effective for the classification of OCT images of normal versus age-related macular degeneration,” Ophthalmol Retina 1(4), 322–327 (2017). [CrossRef]  

7. C. S. Lee, A. J. Tyring, N. P. Deruyter, Y. Wu, A. Rokem, and A. Y. Lee, “Deep-learning based, automated segmentation of macular edema in optical coherence tomography,” Biomed. Opt. Express 8(7), 3440–3448 (2017). [CrossRef]  

8. A. C. Thompson, A. A. Jammal, and F. A. Medeiros, “A review of deep learning for screening, diagnosis, and detection of glaucoma progression,” Trans. Vis. Sci. Tech. 9(2), 42 (2020). [CrossRef]  

9. M. D. Abràmoff, P. T. Lavin, M. Birch, N. Shah, and J. C. Folk, “Pivotal trial of an autonomous AI-based diagnostic system for detection of diabetic retinopathy in primary care offices,” npj Digital Med 1(1), 39 (2018). [CrossRef]  

10. F. Li, H. Chen, Z. Liu, X.-D. Zhang, M.-S. Jiang, Z.-Z. Wu, and K.-Q. Zhou, “Deep learning-based automated detection of retinal diseases using optical coherence tomography images,” Biomed. Opt. Express 10(12), 6204–6226 (2019). [CrossRef]  

11. A. M. Alqudah, “AOCT-NET: a convolutional network automated classification of multiclass retinal diseases using spectral-domain optical coherence tomography images,” Med. Biol. Eng. Comput. 58(1), 41–53 (2020). [CrossRef]  

12. W. Lu, Y. Tong, Y. Yu, Y. Xing, C. Chen, and Y. Shen, “Deep learning-based automated classification of multi-categorical abnormalities from optical coherence tomography images,” Trans. Vis. Sci. Tech. 7(6), 41 (2018). [CrossRef]  

13. S. Yu, H. Ren, N. Manivannan, G. Lee, P. Sha, A. Melo, T. Conti, T. Greenlee, E. R. Chen, K. Talcott, R. P. Singh, N. D’Souza, and M. Durbin, “Performance validation of B-scan of interest algorithm on normative dataset,” Invest. Ophthalmol. Vis. Sci. 61, PB0085 (2020).

14. H. Ren, N. Manivannan, G. C. Lee, S. Yu, P. Sha, T. Conti, A. Melo, T. Greenlee, E. Chen, K. Talcott, R. P. Singh, M. K. Durbin, and N. D’Souza, “Improving OCT B-scan of interest inference performance using TensorRT based neural network optimization,” Invest. Ophthalmol. Vis. Sci. 61, 1635 (2020).

15. M. Raghu, C. Zhang, J. Kleinberg, and S. Bengio, “Transfusion: understanding transfer learning for medical imaging,” in Advances in Neural Information Processing Systems 32, H. Wallach, H. Larochelle, A. Beygelzimer, F. d’Alché-Buc, E. Fox, and R. Garnett, eds. (Curran Associates, Inc., 2019), pp. 3347–3357.

16. S. Kuwayama, Y. Ayatsuka, D. Yanagisono, T. Uta, H. Usui, A. Kato, N. Takase, Y. Ogura, and T. Yasukawa, “Automated detection of macular diseases by optical coherence tomography and artificial intelligence machine learning of optical coherence tomography images,” J. Ophthalmol. 2019, 6319581 (2019). [CrossRef]  

17. L. Greco, G. Percannella, P. Ritrovato, F. Tortorella, and M. Vento, “Trends in IoT based solutions for health care: Moving AI to the edge,” Pattern Recognition Letters 135, 346–353 (2020). [CrossRef]  

18. A. D. Pérez, O. Perdomo, and F. A. González, “A lightweight deep learning model for mobile eye fundus image quality assessment,” in 15th International Symposium on Medical Information Processing and Analysis (International Society for Optics and Photonics, 2020), Vol. 11330, p. 113300K.

19. H. Peng, W. Gong, C. F. Beckmann, A. Vedaldi, and S. M. Smith, “Accurate brain age prediction with lightweight deep neural networks,” Med. Image Anal. 68, 101871 (2021). [CrossRef]  

20. H. Vaseli, Z. Liao, A. H. Abdi, and H. Girgis, “Designing lightweight deep learning models for echocardiography view classification,” Medical Imaging (2019).

21. S. Karakanis and G. Leontidis, “Lightweight deep learning models for detecting COVID-19 from chest X-ray images,” Comput. Biol. Med. 130, 104181 (2021). [CrossRef]  

22. I. Z. Yalniz, H. Jégou, K. Chen, M. Paluri, and D. Mahajan, “Billion-scale semi-supervised learning for image classification,” arXiv [cs.CV] (2019).

23. C. Buciluǎ, R. Caruana, and A. Niculescu-Mizil, “Model compression,” Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining - KDD ‘06 (2006).

24. G. Hinton, O. Vinyals, and J. Dean, “Distilling the knowledge in a neural network,” arXiv [stat.ML] (2015).

25. K. He, X. Zhang, S. Ren, and J. Sun, “Deep residual learning for image recognition,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2016), pp. 770–778.

26. J. Deng, W. Dong, R. Socher, L.-J. Li, K. Li, and L. Fei-Fei, “ImageNet: A large-scale hierarchical image database,” 2009 IEEE Conference on Computer Vision and Pattern Recognition (2009).

27. F. N. Iandola, S. Han, M. W. Moskewicz, and K. Ashraf, “SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5 MB model size,” arXiv [cs.CV] (2016).

28. A. Gholami, K. Kwon, B. Wu, and Z. Tai, “SqueezeNext: hardware-aware neural network design,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops (2018).

29. A. G. Howard, M. Zhu, B. Chen, D. Kalenichenko, W. Wang, T. Weyand, M. Andreetto, and H. Adam, “MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications,” arXiv [cs.CV] (2017).

30. M. Sandler, A. Howard, M. Zhu, A. Zhmoginov, and L.-C. Chen, “Mobilenetv2: Inverted residuals and linear bottlenecks,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2018), pp. 4510–4520.

31. A. Howard, M. Sandler, G. Chu, L.-C. Chen, B. Chen, M. Tan, W. Wang, Y. Zhu, R. Pang, and V. Vasudevan, and Others, “Searching for mobilenetv3,” in Proceedings of the IEEE/CVF International Conference on Computer Vision (2019), pp. 1314–1324.

32. X. Zhang, X. Zhou, M. Lin, and J. Sun, “ShuffleNet: an extremely efficient convolutional neural network for mobile devices,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2018).

33. N. Ma, X. Zhang, H.-T. Zheng, and J. Sun, “Shufflenet v2: Practical guidelines for efficient CNN architecture design,” in Proceedings of the European Conference on Computer Vision (ECCV) (2018), pp. 116–131.

34. D. P. Kingma and J. Ba, “Adam: a method for stochastic optimization,” arXiv [cs.LG] (2014).

35. A. P. Sunija, S. Kar, S. Gayathri, V. P. Gopi, and P. Palanisamy, “OctNET: a lightweight CNN for retinal disease classification from optical coherence tomography images,” Computer Methods and Programs in Biomedicine 200, 105877 (2021). [CrossRef]

36. S. R. Abdani, M. A. Zulkifley, and N. H. Zulkifley, “A lightweight deep learning model for COVID-19 detection,” 2020 IEEE Symposium on Industrial Electronics & Applications (ISIEA) (2020).

37. N. Awasthi, A. Dayal, L. R. Cenkeramaddi, and P. K. Yalavarthy, “Mini-COVIDNet: efficient lightweight deep neural network for ultrasound based point-of-care detection of COVID-19,” IEEE Trans. Ultrason. Ferroelectr. Freq. Control 68(6), 2023–2037 (2021). [CrossRef]

38. D. S. Kermany, M. Goldbaum, W. Cai, C. C. S. Valentim, H. Liang, S. L. Baxter, A. McKeown, G. Yang, X. Wu, F. Yan, J. Dong, M. K. Prasadha, J. Pei, M. Y. L. Ting, J. Zhu, C. Li, S. Hewett, J. Dong, I. Ziyar, A. Shi, R. Zhang, L. Zheng, R. Hou, W. Shi, X. Fu, Y. Duan, V. A. N. Huu, C. Wen, E. D. Zhang, C. L. Zhang, O. Li, X. Wang, M. A. Singer, X. Sun, J. Xu, A. Tafreshi, M. A. Lewis, H. Xia, and K. Zhang, “Identifying Medical Diagnoses and Treatable Diseases by Image-Based Deep Learning,” Cell 172(5), 1122–1131.e9 (2018). [CrossRef]  

39. S. Akbarian, L. Seyyed-Kalantari, F. Khalvati, and E. Dolatabadi, “Evaluating knowledge transfer in neural network for medical images,” arXiv [eess.IV] (2020).

40. S. Chelaramani, M. Gupta, V. Agarwal, P. Gupta, and R. Habash, “Multi-task knowledge distillation for eye disease prediction,” in Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (2021), pp. 3983–3993.

41. S. Reiß, C. Seibold, A. Freytag, E. Rodner, and R. Stiefelhagen, “Every annotation counts: Multi-label deep supervision for medical image segmentation,” in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2021), pp. 9532–9542.

42. S. Sedai, B. Antony, R. Rai, K. Jones, H. Ishikawa, J. Schuman, W. Gadi, and R. Garnavi, “Uncertainty guided semi-supervised segmentation of retinal layers in OCT images,” in Medical Image Computing and Computer Assisted Intervention – MICCAI 2019, Lecture Notes in Computer Science 11764 (Springer International Publishing, 2019), pp. 282–290. [CrossRef]

43. S. Shaw, M. Pajak, A. Lisowska, S. A. Tsaftaris, and A. Q. O’Neil, “Teacher-Student chain for efficient semi-supervised histology image classification,” arXiv [cs.CV] (2020).

44. Q. Liu, L. Yu, L. Luo, Q. Dou, and P. A. Heng, “Semi-supervised medical image classification with relation-driven self-ensembling model,” IEEE Trans. Med. Imaging 39(11), 3429–3440 (2020). [CrossRef]  

45. H.-Y. Zhou, C. Wang, H. Li, G. Wang, S. Zhang, W. Li, and Y. Yu, “SSMD: semi-supervised medical image detection with adaptive consistency and heterogeneous perturbation,” Med. Image Anal. 72, 102117 (2021). [CrossRef]  

46. B. Unnikrishnan, C. M. Nguyen, S. Balaram, C. S. Foo, and P. Krishnaswamy, “Semi-supervised classification of diagnostic radiographs with NoTeacher: a teacher that is not mean,” in Medical Image Computing and Computer Assisted Intervention – MICCAI 2020, Lecture Notes in Computer Science 12261 (Springer International Publishing, 2020), pp. 624–634. [CrossRef]
