
Semi-supervised learning: training with a mix of labeled and pseudo-labeled data

1. Background

When labeled data is scarce, unlabeled data is plentiful, and annotation is expensive, semi-supervised training is worth considering. First, a pseudo-label technique assigns pseudo-labels to the unlabeled images; then the model is trained on a mix of labeled and pseudo-labeled data. Note that every mini-batch must contain both real and pseudo labels. This post walks through a code implementation.

2. Approach and steps

First, a quick review of the pseudo-label technique (see the linked reference): train an initial model on the labeled data, use it to predict labels for the unlabeled data, and treat sufficiently confident predictions as pseudo-labels.

3. Implementation

The first step is generating the pseudo-labels. For both classification and object detection this is straightforward, so it is not covered in detail here.
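For classification, one common way to generate pseudo-labels is to keep only the predictions whose softmax confidence clears a threshold. A minimal sketch (the function name `make_pseudo_labels` and the 0.9 threshold are illustrative choices, not from the original post):

```python
import torch


def make_pseudo_labels(logits, threshold=0.9):
    """Keep predictions whose max softmax probability exceeds `threshold`.

    Returns (kept_indices, pseudo_labels). The threshold is a hypothetical
    hyperparameter; tune it for your task.
    """
    probs = torch.softmax(logits, dim=1)
    conf, preds = probs.max(dim=1)
    keep = conf >= threshold
    return keep.nonzero(as_tuple=True)[0], preds[keep]


# Example: logits for 4 unlabeled images over 3 classes
logits = torch.tensor([[8.0, 0.1, 0.2],   # very confident -> class 0
                       [0.3, 0.4, 0.5],   # uncertain -> dropped
                       [0.1, 9.0, 0.0],   # very confident -> class 1
                       [1.0, 1.1, 0.9]])  # uncertain -> dropped
idx, labels = make_pseudo_labels(logits, threshold=0.9)
print(idx.tolist(), labels.tolist())  # -> [0, 2] [0, 1]
```

The kept indices would then be written out as an image-list file in the same `path label` format the data loader below expects.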

What follows shows how to guarantee that every mini-batch contains both real and pseudo labels, and how to control their ratio, using classification as the example.

Step 1: modify the data-loading code as follows:

```python
import os
import sys
import random

import cv2
import numpy as np
import torch
import torchvision
from torch.utils import data
from torchvision import transforms as T
from PIL import Image

from data_augment import gussian_blur, random_crop


class Dataset(data.Dataset):
    def __init__(self, img_list, img_list1, phase='train'):
        self.phase = phase
        # Real (annotated) samples
        with open(img_list, 'r') as fd:
            imgs = fd.readlines()
        imgs = [img.rstrip("\n") for img in imgs]
        random.shuffle(imgs)
        self.imgs = imgs
        # Pseudo-labeled samples (simulated here)
        with open(img_list1, 'r') as fd:
            fake_imgs = fd.readlines()
        fake_imgs = [img.rstrip("\n") for img in fake_imgs]
        random.shuffle(fake_imgs)
        self.fake_imgs = fake_imgs
        normalize = T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        if self.phase == 'train':
            self.transforms = T.Compose([
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                normalize
            ])
        else:
            self.transforms = T.Compose([
                T.ToTensor(),
                normalize
            ])

    def __getitem__(self, index):
        sample = self.imgs[index]
        splits = sample.split()
        img_path = splits[0]
        # Data augmentation on the real sample
        data = cv2.imread(img_path)
        data = random_crop(data, 0.2)
        data = gussian_blur(data, 0.2)
        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
        data = Image.fromarray(data)
        data = data.resize((224, 224))
        data = self.transforms(data)
        label = np.int32(splits[1])
        # Fetch pseudo-labeled data: two pseudo samples per real sample
        fake_datas, fake_labels = [], []
        for i in range(2):
            fake_sample = self.fake_imgs[(index + i) % len(self.fake_imgs)]
            fake_splits = fake_sample.split()
            fake_img_path = fake_splits[0]
            fake_data = cv2.imread(fake_img_path)
            fake_data = cv2.cvtColor(fake_data, cv2.COLOR_BGR2RGB)
            fake_data = Image.fromarray(fake_data)
            fake_data = fake_data.resize((224, 224))
            fake_data = self.transforms(fake_data)
            fake_label = np.int32(fake_splits[1])
            fake_datas.append(fake_data.float())
            fake_labels.append(fake_label)
        return data.float(), label, fake_datas, fake_labels

    def __len__(self):
        return len(self.imgs)
```
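To see what this loader delivers per mini-batch, here is a toy stand-in that returns tensors instead of images. It shows how PyTorch's default collate batches the pseudo-sample lists, and what the merged batch looks like after concatenation (the class name and tensor shapes are illustrative):

```python
import torch
from torch.utils.data import DataLoader, Dataset as TorchDataset


class ToyMixDataset(TorchDataset):
    """Stand-in for the image Dataset above: each item returns one 'real'
    sample plus a list of two 'pseudo' samples, mimicking its return shape."""

    def __len__(self):
        return 8

    def __getitem__(self, index):
        data = torch.randn(3, 224, 224)          # real image tensor
        label = index % 2                         # real label
        fake_datas = [torch.randn(3, 224, 224) for _ in range(2)]
        fake_labels = [0, 1]                      # pseudo labels
        return data, label, fake_datas, fake_labels


loader = DataLoader(ToyMixDataset(), batch_size=4)
inputs, targets, fake_inputs, fake_targets = next(iter(loader))
# Default collate transposes the per-item lists, so fake_inputs is a
# list of 2 tensors, each of shape [4, 3, 224, 224].
merged = torch.cat(fake_inputs + [inputs], dim=0)
print(merged.shape)  # torch.Size([12, 3, 224, 224]): 4 real + 8 pseudo
```

With two pseudo samples per real sample, every merged batch has a fixed 1:2 real-to-pseudo ratio; changing the `range(2)` loop in `__getitem__` changes that ratio.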

Step 2: the corresponding changes in the main training loop:

```python
from tqdm import tqdm


def train(epoch, net, trainloader, optimizer, criterion):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    batch_id = 0
    for (inputs, targets, fake_inputs, fake_targets) in tqdm(trainloader):
        # Merge the real and pseudo-labeled samples into one batch
        fake_inputs.append(inputs)
        fake_targets.append(targets)
        inputs = torch.cat(fake_inputs, dim=0)
        targets = torch.cat(fake_targets, dim=0)
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets.long())
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets.long()).sum().item()
        iters = epoch * len(trainloader) + batch_id
        if iters % 10 == 0:
            acc = predicted.eq(targets.long()).sum().item() * 1.0 / targets.shape[0]
            los = loss * 1.0 / targets.shape[0]
            # tensor_board.visual_loss("train_loss", los, iters)
            # tensor_board.visual_acc("train_acc", acc, iters)
        batch_id += 1
```
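The loop above weights real and pseudo-labeled samples equally in the loss. The classic pseudo-label formulation (Lee, 2013) instead down-weights the unlabeled term. A sketch of that variant, assuming the real samples occupy the last `n_real` rows of the batch (matching the `torch.cat` order in `train`, where the real batch is appended last); `mixed_loss` and the `alpha` value are illustrative, not from the original post:

```python
import torch
import torch.nn.functional as F


def mixed_loss(outputs, targets, n_real, alpha=0.5):
    """Cross-entropy with the pseudo-labeled portion scaled by alpha.

    Assumes the batch layout produced by train(): pseudo samples first,
    the n_real real samples last.
    """
    per_sample = F.cross_entropy(outputs, targets.long(), reduction='none')
    pseudo_loss = per_sample[:-n_real].mean()   # first rows: pseudo labels
    real_loss = per_sample[-n_real:].mean()     # last rows: real labels
    return real_loss + alpha * pseudo_loss


# Toy batch: 8 pseudo + 4 real samples, 5 classes
outputs = torch.randn(12, 5)
targets = torch.randint(0, 5, (12,))
loss = mixed_loss(outputs, targets, n_real=4, alpha=0.5)
print(loss.item())  # a non-negative scalar
```

In the original paper `alpha` is also ramped up over the course of training so early, noisy pseudo-labels contribute less.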

That's all there is to it. For the theory, see my other blog post.

Related: https://blog.csdn.net/p_lart/article/details/100128353

