rec/fdeep/filter.hpp

101 lines
2.6 KiB
C++
Raw Normal View History

2020-03-18 06:42:46 +00:00
// Copyright 2016, Tobias Hermann.
// https://github.com/Dobiasd/frugally-deep
// Distributed under the MIT License.
// (See accompanying LICENSE file or at
// https://opensource.org/licenses/MIT)
#pragma once
#include "fdeep/common.hpp"
#include "fdeep/tensor5.hpp"
#include "fdeep/shape5.hpp"
#include <cassert>
#include <cstddef>
#include <vector>
namespace fdeep { namespace internal
{
// A filter: a weight tensor (tensor5) paired with a single scalar bias.
class filter
{
public:
    // Stores a copy of the weight tensor m and the bias value.
    filter(const tensor5& m, float_type bias) : m_(m), bias_(bias)
    {
    }

    // Shape of the underlying weight tensor.
    const shape5& shape() const
    {
        return m_.shape();
    }

    // Number of weights, i.e. the volume of the weight tensor's shape.
    std::size_t volume() const
    {
        return m_.shape().volume();
    }

    // Read-only access to the underlying weight tensor.
    const tensor5& get_tensor5() const
    {
        return m_;
    }

    // Returns the weight at (y, x, z); the two leading tensor5 indices
    // are always 0 here.
    // All three indices are std::size_t: <cstddef> only guarantees the
    // std-qualified name, so the unqualified spelling is avoided.
    float_type get(std::size_t y, std::size_t x, std::size_t z) const
    {
        return m_.get(0, 0, y, x, z);
    }

    // The scalar bias stored alongside the weights.
    float_type get_bias() const
    {
        return bias_;
    }

    // Replaces all weights and the bias while keeping the current shape.
    // weights.size() must equal the volume of the current shape; this is
    // checked through assertion() with the message "invalid parameter count".
    void set_params(const float_vec& weights, float_type bias)
    {
        assertion(weights.size() == m_.shape().volume(),
            "invalid parameter count");
        m_ = tensor5(m_.shape(), float_vec(weights));
        bias_ = bias;
    }

private:
    tensor5 m_;       // filter weights
    float_type bias_; // scalar bias
};
// Sequence of filters.
using filter_vec = std::vector<filter>;
// Builds a new filter whose weight tensor is the result of
// dilate_tensor5(dilation_rate, <tensor of undilated>); the bias is
// carried over unchanged. The input filter is not modified.
inline filter dilate_filter(const shape2& dilation_rate, const filter& undilated)
{
    const auto dilated_weights =
        dilate_tensor5(dilation_rate, undilated.get_tensor5());
    const float_type carried_bias = undilated.get_bias();
    return filter(dilated_weights, carried_bias);
}
// Creates k filters of shape filter_shape, fills them from the flat
// weights and bias vectors, and dilates each one by dilation_rate.
//
// weights is cut into consecutive chunks of filter_shape.volume() values,
// one chunk per filter; bias supplies one value per filter.
// Checked via assertion():
//   - k must be at least 1,
//   - weights.size() must equal the summed volume of all filters,
//   - the number of weight chunks and bias.size() must both equal k.
inline filter_vec generate_filters(
    const shape2& dilation_rate,
    const shape5& filter_shape, std::size_t k,
    const float_vec& weights, const float_vec& bias)
{
    // Start from k identical zero-filled filters with zero bias.
    filter_vec result(k, filter(tensor5(filter_shape, 0), 0));
    assertion(!result.empty(), "at least one filter needed");

    // Total number of weights expected over all filters.
    const auto volumes = fplus::transform(
        fplus_c_mem_fn_t(filter, volume, std::size_t), result);
    const std::size_t expected_weight_count = fplus::sum(volumes);
    assertion(static_cast<std::size_t>(weights.size()) == expected_weight_count,
        "invalid weight size");

    // One chunk of weights per filter.
    const auto weights_per_filter = result.front().shape().volume();
    auto weight_chunks = fplus::split_every(weights_per_filter, weights);
    assertion(weight_chunks.size() == result.size(),
        "invalid size of filter weights");
    assertion(bias.size() == result.size(), "invalid bias size");

    // Walk filters, weight chunks and biases in lock-step.
    auto chunk_it = std::begin(weight_chunks);
    auto bias_it = std::begin(bias);
    for (auto filt_it = std::begin(result); filt_it != std::end(result);
         ++filt_it, ++chunk_it, ++bias_it)
    {
        filt_it->set_params(*chunk_it, *bias_it);
        // Dilation happens after the real weights are in place.
        *filt_it = dilate_filter(dilation_rate, *filt_it);
    }
    return result;
}
} } // namespace fdeep, namespace internal