导航菜单
首页 >  神经风格迁移 TensorFlow Core  > 使用 TensorFlow Hub 进行迁移学习  

使用 TensorFlow Hub 进行迁移学习  

在 TensorFlow.org 上查看 在 Google Colab 中运行在 GitHub 中查看源代码 下载笔记本 查看 TF Hub 模型

TensorFlow Hub 是预训练的 TensorFlow 模型的仓库。

此教程演示了如何执行以下操作:

将来自 TensorFlow Hub 的模型与 tf.keras 结合使用使用来自 TensorFlow Hub 的图像分类模型进行简单的迁移学习,针对您自己的图像类微调模型设置import numpy as npimport timeimport PIL.Image as Imageimport matplotlib.pylab as pltimport tensorflow as tfimport tensorflow_hub as hubimport datetime%load_ext tensorboard2023-11-07 23:03:51.866811: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered2023-11-07 23:03:51.866865: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered2023-11-07 23:03:51.868626: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registeredImageNet 分类器

您将首先使用预训练的分类器模型获取图像并预测它是什么图像 - 无需训练!

下载分类器

从 TensorFlow Hub 中选择一个 MobileNetV2 预训练模型,并将其封装为带有 hub.KerasLayer 的 hub.KerasLayer 层。可以在这里使用任何来自 TensorFlow Hub 的兼容的图像分类器模型,包括下面下拉列表中提供的示例。

mobilenet_v2 ="https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"inception_v3 = "https://tfhub.dev/google/imagenet/inception_v3/classification/5"classifier_model = mobilenet_v2IMAGE_SHAPE = (224, 224)classifier = tf.keras.Sequential([hub.KerasLayer(classifier_model, input_shape=IMAGE_SHAPE+(3,))])对单个图像运行分类器

下载要在模型上尝试的单个图像。

grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)grace_hopperDownloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg61306/61306 [==============================] - 0s 0us/step

png

grace_hopper = np.array(grace_hopper)/255.0grace_hopper.shape(224, 224, 3)

添加批量维度(使用 np.newaxis)并将图像传递给模型:

result = classifier.predict(grace_hopper[np.newaxis, ...])result.shape1/1 [==============================] - 2s 2s/step(1, 1001)

结果是一个 1001 元素的 logits 向量,同时对图像属于每个类别的概率进行评分。

顶部类 ID 可以通过 tf.math.argmax 找到:

predicted_class = tf.math.argmax(result[0], axis=-1)predicted_class解码预测

获取 predicted_class ID(例如 653)并获取 ImageNet 数据集标签以解码预测:

labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')imagenet_labels = np.array(open(labels_path).read().splitlines())Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt10484/10484 [==============================] - 0s 0us/stepplt.imshow(grace_hopper)plt.axis('off')predicted_class_name = imagenet_labels[predicted_class]_ = plt.title("Prediction: " + predicted_class_name.title())

png

简单的迁移学习

但是,如果您想使用自己的数据集创建一个自定义分类器,但该数据集的类未包含在原始 ImageNet 数据集中(预训练模型已基于该数据集进行训练),此时该如何处理?

为此,您可以:

从 TensorFlow Hub 中选择一个预训练模型;重新训练顶部(最后一个)层以识别自定义数据集中的类。数据集

在本例中,您将使用 TensorFlow 花卉数据集:

import pathlibdata_file = tf.keras.utils.get_file( 'flower_photos.tgz', 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz', cache_dir='.',extract=True)data_root = pathlib.Path(data_file).with_suffix('')Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz228813984/228813984 [==============================] - 1s 0us/step

首先,使用 tf.keras.utils.image_dataset_from_directory 将磁盘上的图像数据加载到模型中,这将生成一个 tf.data.Dataset:

batch_size = 32img_height = 224img_width = 224train_ds = tf.keras.utils.image_dataset_from_directory( str(data_root), validation_split=0.2, subset="training", seed=123, image_size=(img_height, img_width), batch_size=batch_size)val_ds = tf.keras.utils.image_dataset_from_directory( str(data_root), validation_split=0.2, subset="validation", seed=123, image_size=(img_height, img_width), batch_size=batch_size)Found 3670 files belonging to 5 classes.Using 2936 files for training.Found 3670 files belonging to 5 classes.Using 734 files for validation.

花卉数据集有五个类。

class_names = np.array(train_ds.class_names)print(class_names)['daisy' 'dandelion' 'roses' 'sunflowers' 'tulips']

其次,由于 TensorFlow Hub 对图像模型的约定是期望浮点输入在 [0, 1] 范围内,因此使用 tf.keras.layers.Rescaling 预处理层来实现这一点。

注:您还可以在模型中包含 tf.keras.layers.Rescaling 层。有关权衡的讨论,请参阅使用预处理层指南。

normalization_layer = tf.keras.layers.Rescaling(1./255)train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.

第三,通过使用 Dataset.prefetch 的缓冲预提取来完成输入流水线,这样您就可以从磁盘产生数据而不会出现 I/O 阻塞问题。

这些是加载数据时应该使用的一些最重要的 tf.data 方法。感兴趣的读者可以在使用 tf.data API 获得更高性能指南中了解有关它们的

相关推荐: