rec/fdeep/layers/batch_normalization_layer.hpp
2020-03-18 14:42:46 +08:00

92 lines
2.7 KiB
C++
Executable File

// Copyright 2016, Tobias Hermann.
// https://github.com/Dobiasd/frugally-deep
// Distributed under the MIT License.
// (See accompanying LICENSE file or at
// https://opensource.org/licenses/MIT)
#pragma once
#include "fdeep/layers/layer.hpp"
#include <string>
namespace fdeep { namespace internal
{
// https://kratzert.github.io/2016/02/12/understanding-the-gradient-flow-through-the-batch-normalization-layer.html
// https://stackoverflow.com/a/46444452/1866775
class batch_normalization_layer : public layer
{
public:
explicit batch_normalization_layer(const std::string& name,
const float_vec& moving_mean,
const float_vec& moving_variance,
const float_vec& beta,
const float_vec& gamma,
float_type epsilon)
: layer(name),
moving_mean_(moving_mean),
moving_variance_(moving_variance),
beta_(beta),
gamma_(gamma),
epsilon_(epsilon)
{
}
protected:
float_vec moving_mean_;
float_vec moving_variance_;
float_vec beta_;
float_vec gamma_;
float_type epsilon_;
tensor5 apply_to_slices(const tensor5& input) const
{
assertion(moving_mean_.size() == input.shape().depth_,
"invalid beta");
assertion(moving_variance_.size() == input.shape().depth_,
"invalid beta");
const bool use_gamma = !gamma_.empty();
if (use_gamma)
{
assertion(gamma_.size() == input.shape().depth_, "invalid gamma");
}
const bool use_beta = !beta_.empty();
if (use_beta)
{
assertion(beta_.size() == input.shape().depth_, "invalid beta");
}
tensor5 output(input.shape(), 0);
for (std::size_t z = 0; z < output.shape().depth_; ++z)
{
const float_type denom = std::sqrt(moving_variance_[z] + epsilon_);
for (std::size_t y = 0; y < output.shape().height_; ++y)
{
for (std::size_t x = 0; x < output.shape().width_; ++x)
{
float_type val = input.get(0, 0, y, x, z);
val -= moving_mean_[z];
if (use_gamma)
val *= gamma_[z];
val /= denom;
if (use_beta)
val += beta_[z];
output.set(0, 0, y, x, z, val);
}
}
}
return output;
}
tensor5s apply_impl(const tensor5s& inputs) const override
{
assertion(inputs.size() == 1, "invalid number of tensors");
const auto& input = inputs.front();
return {apply_to_slices(input)};
}
};
} } // namespace fdeep, namespace internal