Scikit-Learn接口包装器
我们可以通过包装器将Sequential模型(仅有一个输入)作为Scikit-Learn工作流的一部分,相关的包装器定义在keras.wrappers.scikit_learn.py中
目前,有两个包装器可用:
keras.wrappers.scikit_learn.KerasClassifier(build_fn=None, **sk_params)实现了sklearn的分类器接口
keras.wrappers.scikit_learn.KerasRegressor(build_fn=None, **sk_params)实现了sklearn的回归器接口
参数
-
build_fn:可调用的函数或类对象
-
sk_params:模型参数和训练参数
build_fn应构造、编译并返回一个Keras模型,该模型将稍后用于训练/测试。build_fn的值可能为下列三种之一:
-
一个函数
-
一个具有
call方法的类对象 -
None,代表你的类继承自
KerasClassifier或KerasRegressor,其call方法为其父类的call方法
sk_params以模型参数和训练(超)参数作为参数。合法的模型参数为build_fn的参数。注意,‘build_fn’应提供其参数的默认值。所以我们不传递任何值给sk_params也可以创建一个分类器/回归器
sk_params还接受用于调用fit,predict,predict_proba和score方法的参数,如nb_epoch,batch_size等。这些用于训练或预测的参数按如下顺序选择:
-
传递给
fit,predict,predict_proba和score的字典参数 -
传递个
sk_params的参数 -
keras.models.Sequential,fit,predict,predict_proba和score的默认值
当使用scikit-learn的grid_search接口时,合法的可转换参数是你可以传递给sk_params的参数,包括训练参数。即,你可以使用grid_search来搜索最佳的batch_size或nb_epoch以及其他模型参数