使用迁移学习快速训练识别特定风格的图片

使用迁移学习快速训练识别特定风格的图片

前几天接到一个任务,需要从我们app的feed流中的筛选一些「优质」图片,作为运营同学的精选feed候选池。这里「优质」的参考就是以前运营同学手工筛序的精选feed图片。问题并不难,最容易想到的方向有两个:

  1. 机器学习方向,训练一个能够识别这种「优质」风格图片的模型。
  2. 过滤推荐方向,利用用户来测试feed图片质量(根据点赞、评论、观看张数、停留时间等指标),使用用户来筛选优质feed图片(用户的偏好千奇百怪,筛选结果可能未必如你所想,典型如今日头条……)。

今天我们介绍如何使用机器学习解决这个问题。具体来讲,由于时间紧,任务重,我们决定使用迁移学习来完成这个任务。后面如果有时间,我们也会尝试一下使用用户来过滤和筛选优质图片。

什么是迁移学习

迁移学习 (Transfer learning) 顾名思义就是就是把已学训练好的模型参数迁移到新的模型来帮助新模型训练。考虑到大部分数据或任务是存在相关性的,所以通过迁移学习我们可以将已经学到的模型参数(也可理解为模型学到的知识)通过某种方式来分享给新模型从而加快并优化模型的学习效率不用像大多数网络那样从零学习。

为什么使用迁移学习

  • 很多时候,你可能并没有足够大的数据集来训练模型,更不用说带有高质量标签的数据集了。使用已经训练好的网络,可以降低用于训练的数据集大小要求。
  • 从零开始训练一个深度网络是非常消耗算力和时间的。如果再将模型调整、超参数调整等有点玄学的流程加进去,消耗的时间会更多。对于创业公司来说,很多时候是很难给出这么多的时间预算来解决一个模型问题的。
  • 基于迁移学习训练一个模型往往只需要训练有限的几层网络,或者使用已有网络作为特征生成器,使用常规机器学习方法(如svm)来训练分类器。整体训练时间大幅降低。效果可能不是最好的,但是往往能够在短时间内帮你训练出一个够用的模型,解决当前的实际问题。

也就是说,近几年深度学习的各种突破本质上还是建立在数据集的完善和算力的提升。算法方面的提升带来的突破其实不如前两者明显。如果你是一个开发者,具体到要使用机器学习解决特定问题的时候,你一定想清楚你能否搞定数据集和算力的问题,如果不能,不妨尝试一下迁移学习。

如何进行迁移学习

我们的任务是筛选优质feed图片,其实就是一个优质图片与普通图片的二分类问题。

运营给出的「优质」参考图片:

直观感受是,健身摆拍图、美食图和少量风光照是她们眼中的优质图片?

运营给出的「普通」参考图片:

直观感受是,屏幕截图和没什么特点的图片被认为是普通图片。

我们迁移学习的过程就是复用训练好的(部分)网络和权重,然后构建我们自己的模型进行训练:

迁移学习在选择预训练网络时有一点需要注意:预训练网络与当前任务差距不大,否则迁移学习的效果会很差。这里根据我们的任务类型,我们选择了深度残差网络 ResNet50, 权重选择imagenet数据集。选择 RetNet 的主要原因是之前我们训练的图片鉴黄模型是参考雅虎开源的 open NSFW , 而这个模型使用的就是残差网络,模型效果让我们影响深刻。完整代码如下(keras + tensorflow):

  • 这里我们仅重新训练了输出层,你也可以根据自己需要添加多个自定义层。
  • 整个训练过程非常快,在Macbook late 2013仅使用CPU训练的情况下,不到一个小时收敛到了82%的准确率。考虑到我们的「优质」图片标签质量不太高的实际情况,这个准确率是可以接受的。
  • 完成训练后,我们使用该模型对生产环境的2000张实时图片进行了筛选,得到85张图片,运营主观打分结果是~50%可用,~25%需要结合多图考虑,其他不符合要求。考虑到我们的任务是辅助他们高效发现和筛选潜在优质图片,这个结果他们还是认可的。部分筛选结果如下:

还可以更简单一点吗?

如果你觉得上面重新训练网络还是太慢、太繁琐,我们还有更简单的迁移学习的方法:将预训练网络作为特征提取器,然后使用机器学习方法来训练分类器。以SVM为例,完成迁移学习只需要两个步骤:

  • 将预训练网络最后一层输出作为特征提取出来:
resnet_model = None

def extract_resnet(x):
  '''
  :param x: images numpy array 
  :return: features
  '''
  global resnet_model
  if resnet_model is None:
    resnet_model = ResNet50(include_top=False,
                            weights='imagenet',
                            input_shape=(image_h, image_w, 3))
  features_array = resnet_model.predict(x)
  return np.squeeze(features_array)

  • 使用特征训练SVM分类器:
def train(positive_feature_file, negative_feature_file):

  p_x = np.load(positive_feature_file)
  n_x = np.load(negative_feature_file)

  p_y = np.ones((len(p_x),), dtype=int)
  n_y = np.ones((len(n_x),), dtype=int) * -1

  x = np.concatenate((p_x, n_x), axis=0)
  y = np.concatenate((p_y, n_y), axis=0)

  x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
  logging.info("train shape:%s", x_train.shape)


  pca = PCA(n_components=512, whiten=True)
  pca = pca.fit(x)
  x_train = pca.transform(x_train)
  x_test = pca.transform(x_test)

  logging.info("train shape:%s", x_train.shape)


  # train
  svm_clf = svm.SVC(kernel='rbf', probability=True, decision_function_shape='ovr')
  svm_clf.set_params(C=0.4)
  svm_clf.fit(x_train, y_train)
  preds = svm_clf.predict(x_train)
  logging.info('train preds %d items, train accuracy:%.04f', len(preds), accuracy_score(preds, y_train))

  preds = svm_clf.predict(x_test)
  logging.info('test preds %d items, test accuracy:%.04f', len(preds), accuracy_score(preds, y_test))

  # joblib.dump(ss, './normal-ss.pkl')
  joblib.dump(pca, './normal-pca.pkl')
  joblib.dump(svm_clf, './normal-clf.pkl')

这个方法之所以有效是因为,迁移学习要求预训练网络与当前任务是相似的,那么最后一层网络的输出可以解释为特征的高度抽象,因此可以使用其作为特征进行分类。

这个方法虽然有效,但是需要准备两个数据集:正样本和负样本。很多时候,我们的任务是识别出我们关心的类别,这个类别我们可以花时间和精力来进行数据集的标注,但是对于我们不关心的类别的数据往往是不易收集的。那么,我们可以只准备一个数据集来训练一个只识别我们关心类别的模型吗?答案是可以的,使用One-class classification即可,一般翻译为异常检测或离群点检测。

如果你熟悉sklearn, 你可以使用svm.OneClassSVM:

  oc_svm_clf = svm.OneClassSVM(gamma=0.011, kernel='rbf', nu=0.08)
  oc_svm_clf.fit(x_train)

  preds = oc_svm_clf.predict(x_train)
  expects = np.ones((len(preds)), dtype=int)
  logging.info('train preds %d items, train accuracy:%.04f', len(preds), accuracy_score(preds, expects))

需要注意的是,One-class classification是一种无监督学习,从实验效果看,使用该方法筛选出来的图片「稳定性」相比前面两个方法稳定性要差。如果要在实际业务中使用该方法,需要仔细调整gamma参数,根据ROC曲线寻找一个相对理想的值。

小结

大多数场景下,受限于数据集、算力和时间限制,很少人是从零开始训练一个深度神经网络的。如果你的任务是解决工程中的某个特定问题,那么迁移学习可能是一个有效的高性价比解决方案。你可以使用通过添加或移除若干预训练网络层来实现迁移学习,也可以将预训练网络作为特征提取器,然后使用其他分类方法进行机器学习。迁移学习的效果往往不如完全训练整个网络的效果好,因此,你需要结合具体任务来权衡准确率和成本。

扩展阅读

What a May Day

周四晚上跟生日的父母吃过晚饭后,傻乎乎的带着小梦梦在孩子王逛,一点没有意识到已经是五月的最后一天。有时候,记录的习惯更多的就是提醒自己,时间的昼夜不舍。

这个月有意识的接触了挺多人,对于鄙人这种社交贫瘠的人来说,这个月花在这方面的时间算是奢侈的了。有老友也有新朋友。一次跟新朋友印象深刻的夜谈。已经很久没有跟新认识的人如此没有负担的沟通和交流了。对了,上一次有意思的聊天也是去年的这个时候。初夏,真是一个神奇的时节,一切都开始要变得明亮而耀眼。

五月在做和要做的事情越来越多,一种此情无计可消除,才下眉头,却上心头之感。一部分算是甜蜜的负担,而立之年,一些事情逐渐跟自己是否准备好已经没有必然关系,而是直接去解决它就对了。一部分是多过去时光的辜负,亡羊补牢,希望犹未晚矣。前天看韩老师的5X兴趣社区,看到李伟龙一个视频的幕后花絮,对话挺走心的:时间只会让你老去,其他什么都不会带来;只有你想改变的时候,你才能改变。

神奇的五月,居然达成了跑渣的第一次5公里(一个都不好意思提的配速)。从去年4月份参加跑团开始跑步,到现在已经一年多了,跑步成绩上没有任何提升,也是我预料之中的。对于这件事,我其实想得很明白:我一点都不喜欢运动,但是要支撑我的情怀和要做的事情,我必须要有这个练习和准备。显然,如果保持当前的做事的节奏,也许一周一次的跑步很快就无法支撑自己在做的事情,但是只要保持这件事情的惯性,我相信这股力量不会让自己失望。

这个月最喜欢的书是吴军老师的《智能时代》。因为一直在订阅吴军老师在得到的专栏,因此书中的很多内容其实都在专栏中听过了(如此说来,维护专栏及时高产如母猪,也是需要有存货当备份的?)。有两点体会最深:

  1. 人类文明发展是一个不短加速的过程,每一次加速都会让已有产业与新结束结合形成形成新的产业,赶上这个浪潮的会以数量级的优势领先,赶不上或者不愿拥抱变革的则会被无情的淘汰。
  2. 大数据和AI是当前最有可能成为下一个时代的蒸汽机和电。超越时代是困难的,但是从思维方式上则是可以刻意练习大数据和AI思维的。对于程序员而言,这尤为重要——有很大可能性,这决定了当前的你是成为为工业时代的码农,还是智能时代的工程师。

六月会迎来自己在两个月前设定的一个deadline, 从目前看来,不容乐观。可能当时在设定这个目标的时候,其实内心的真实独白就已经是法乎其中则得其下,法乎其上则得其中。但是,总的来说,过去的两个月无论是还在发生还是已经发生的事,多少带来了一丝丝改变。

期待六月,不负好时光。

微信小程序文件上传二三事

这段时间陆陆续续上了好几个微信小程序,功能上都会用到文件上传功能(头像上传、证件照上传等)。在APP上传文件到云端的正确姿势中,我们介绍了我们认为安全的上传流程:

即将密钥保存在服务器,客户端每次向服务器申请一个一次性的signature,然后使用该signature作为凭证来上传文件。一般情况下,向阿里云OSS上传内容,又拍云作为灾备。

随着大家安全意识的增强,这种上传流程几乎已经成为标准姿势。但是,把这个流程在应用到微信小程序却有很多细节需要调整。这里把踩过的坑记录一下,希望能让有需要的同学少走弯路。

微信小程序无法直接读取文件内容进行上传

在我们第一版的上传流程方案中,我们的cds 签名发放服务只实现了阿里云 PutObject 接口的signature发放. PutObject 上传是直接将需要上传的内容以二进流的方式 PUT 到云储存。

但是,微信小程序提供的文件上传API wx.uploadFile 要求文件通过 filePath 提供:

另一方面,微信小程序的 JS API 当前还比较封闭,无法根据 filePath 读取到文件内容,因此也无法通过 wx.request 直接发起网络请求的方式来实现文件上传。

考虑到 wx.uploadFile 本质上是一个 multipart/form-data 网络请求的封装,因此我们只需要实现一个与之对应的签名发放方式接口。阿里云OSS对应的上传接口是 PostObject, 又拍云对应的是其 FORM API. 以阿里云OSS为例,cds 服务生成signature 代码如下:

func GetDefaultOSSPolicyBase64Str(bucket, key string) string {
    policy := map[string]interface{}{
        "expiration": time.Now().AddDate(3, 0, 0).Format("2006-01-02T15:04:05.999Z"),
        "conditions": []interface{}{
            map[string]string{
                "bucket": bucket,
            },
            []string{"starts-with", "$key", key},
        },
    }
    data, _ := json.Marshal(&policy)
    return base64.StdEncoding.EncodeToString(data)
}

func GetOSSPostSignature(secret string, policyBase64 string) string {
    h := hmac.New(sha1.New, []byte(secret))
    io.WriteString(h, policyBase64)
    return base64.StdEncoding.EncodeToString(h.Sum(nil))
}

小程序端代码如下:

//使用说明
/**
 * 1、引入该文件:const uploadFile = require('../../common/uploadAliyun.js');
 * 2、调用如下:
 * uploadImg: function () {
        const params = {
            _success: this._success
        }
        uploadFile.chooseImg(params);
    },
    _success: function(imgUrl){
        this.setData({
            cover_url: imgUrl,
        })
    },
*/

const uploadFile = {
    _fail: function(desc) {
        wx.showToast({
            icon: "none",
            title: desc
        })
    },
    _success: function() {},
    chooseImg: function(sendData) {
        //先存储传递过来的回调函数
        this._success = sendData._success;
        var that = this;
        wx.chooseImage({
            count: 1,
            sizeType: ['original', 'compressed'],
            sourceType: ["album", "carmera"],
            success: function (res) {
                that.getSign(res.tempFilePaths[0]);
            },
            fail: function (err) {
                wx.showToast({
                    icon: "none",
                    title: "选择图片失败" + err
                })
            }
        })
    },
    //获取阿里上传图片签名
    getSign: function (path) {
        var that = this;
        wx.request({
            url: 'https://somewhere/v2/cds/apply_upload_signature',
            method: 'POST',
            data: {
                "content_type": "image/jpeg",
                "signature_type": "oss_post",
                "business": "xiaochengxu",
                "file_ext": '.jpeg',
                "count": '1'
            },
            success: function (res) {
                let getData = res.data.data[0];
                that.startUpload(getData, path);
            },
            fail: function (err) {
                that._fail("获取签名失败" + JSON.stringify(err))
            }
        })
    },
    //拿到签名后开始上传
    startUpload: function (getData, path) {
        var that = this;
        this.uploadAliYun({
            filePath: path,
            dir: 'wxImg/',
            access_key_id: getData.oss_ext_param.access_key_id,
            policy_base64: getData.oss_ext_param.policy_base64,
            signature: getData.signature,
            upload_url: getData.upload_url,
            object_key: getData.oss_ext_param.object_key,
            content_url: getData.content_url.origin 
        })
    },
    uploadAliYun: function(params) {
        var that = this;
        // if (!params.filePath || params.filePath.length < 9) {
        if (!params.filePath) {
            wx.showModal({
                title: '图片错误',
                content: '请重试',
                showCancel: false,
            })
            return;
        }
        const aliyunFileKey = params.dir + params.filePath.replace('wxfile://', '');

        const aliyunServerURL = params.upload_url;
        const accessid = params.access_key_id;
        const policyBase64 = params.policy_base64;
        const signature = params.signature;
        wx.uploadFile({
            url: aliyunServerURL,
            filePath: params.filePath,
            name: 'file',
            formData: {
                'key': params.object_key,
                'policy': policyBase64,
                'OSSAccessKeyId': accessid,
                'Signature': signature
            },
            success: function (res) {
                if (res.statusCode != 204) {
                    that._fail("上传图片失败");
                    return;
                }
                that._success(params.content_url);
            },
            fail: function (err) {
                that._fail(JSON.stringify(err));
            },
        })
    }
}


module.exports = uploadFile;

使用阿里云OSS域名上传失败

解决签名问题后,发现使用阿里云OSS提供的上传域名无法上传成功,在微信后台尝试添加合法域名的时候,惊奇的发现阿里云OSS的域名直接被微信小程序封禁了:

显然是两个神仙在打架,作为草民只能见招拆招。解决办法就是在阿里云OSS -> bucket -> 域名管理 绑定用户域名:

此外,由于微信小程序已经升级为uploadFile的链接必须是https, 因此还需要在绑定用户域名后设置 证书托管

他山之石,可以攻玉

既然微信能够封禁用阿里云OSS的上传域名,那么微信也可以封禁你自定义的域名。根据以往经验(对天发誓,我们不是有意为之,我们也是受害者……),微信封禁域名一般都是一锅端,即发现一个子域名存在违规内容,那么整个域名都会被封禁。因此,一方面要从技术角度对上传的内容及时检查是否合规(如黄图扫描),另一方面提前做好域名规划,将业务接口域名与自定义的文件上传域名分开,这样即使上传域名被一锅端了,不至于是业务完全不可用。