weight_init namespace
2015-12-03 10:55:18 0 举报
AI智能生成
weight_init namespace
作者其他创作
大纲/内容
lecun:scalable
// LeCun-style initializer; the default scale is 1.
lecun() : scalable(static_cast<float_t>(1.0)) {}
// LeCun-style initializer with a caller-chosen scale.
explicit lecun(float_t value) : scalable(value) {}
// Polymorphic copy: a freshly allocated lecun carrying the same scale.
virtual lecun* clone() const { return new lecun(scale_); }
// Fill *weight via uniform_rand between -bound and +bound,
// where bound = scale_ / sqrt(fan_in). Only fan-in matters here.
void fill(vec_t *weight, layer_size_t fan_in, layer_size_t fan_out) {
    CNN_UNREFERENCED_PARAMETER(fan_out);  // LeCun bound ignores fan-out
    const float_t bound = scale_ / std::sqrt(fan_in);
    const float_t lower = -bound;
    uniform_rand(weight->begin(), weight->end(), lower, bound);
}
function
virtual void fill(vec_t *weight, layer_size_t fan_in, layer_size_t fan_out) = 0
virtual function* clone() const = 0;
scalable:function
scalable(float_t value) : scale_(value) {}
void scale(float_t value) {scale_ = value;}
protected:
float_t scale_;
xavier:scalable
// Xavier initializer; the default scale is 6.
xavier() : scalable(static_cast<float_t>(6.0)) {}
// Xavier initializer with a caller-chosen scale.
explicit xavier(float_t value) : scalable(value) {}
// Polymorphic copy: a freshly allocated xavier carrying the same scale.
virtual xavier* clone() const { return new xavier(scale_); }
// Fill *weight via uniform_rand between -bound and +bound,
// where bound = sqrt(scale_ / (fan_in + fan_out)).
void fill(vec_t *weight, layer_size_t fan_in, layer_size_t fan_out) {
    const float_t ratio = scale_ / (fan_in + fan_out);
    const float_t bound = std::sqrt(ratio);
    uniform_rand(weight->begin(), weight->end(), -bound, bound);
}
constant:scalable
constant() : scalable((float_t)0.0) {}
explicit constant(float_t value) : scalable(value) {}
void fill(vec_t *weight, layer_size_t fan_in, layer_size_t fan_out) {
CNN_UNREFERENCED_PARAMETER(fan_in);
CNN_UNREFERENCED_PARAMETER(fan_out);
std::fill(weight->begin(), weight->end(), scale_);
}
virtual constant* clone() const { return new constant(scale_); }
0 条评论
下一页