// Copyright 2016, Tobias Hermann.
// https://github.com/Dobiasd/frugally-deep
// Distributed under the MIT License.
// (See accompanying LICENSE file or at
// https://opensource.org/licenses/MIT)

#pragma once
|
|
|
|
#include "fdeep/common.hpp"
|
|
|
|
#include "fdeep/tensor5.hpp"
|
|
|
|
#include "fdeep/node.hpp"
|
|
|
|
#include <cstddef>
|
|
#include <memory>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
namespace fdeep { namespace internal
|
|
{
|
|
|
|
// Forward declaration so the pointer aliases below can name the type
// before it is defined.
class layer;

// Shared-ownership handle to a layer.
using layer_ptr = std::shared_ptr<layer>;

// Sequence of layer handles.
using layer_ptrs = std::vector<layer_ptr>;
|
|
|
|
class activation_layer;
|
|
typedef std::shared_ptr<activation_layer> activation_layer_ptr;
|
|
tensor5s apply_activation_layer(const activation_layer_ptr& ptr,
|
|
const tensor5s& input);
|
|
|
|
class layer
|
|
{
|
|
public:
|
|
explicit layer(const std::string& name)
|
|
: name_(name), nodes_(), activation_(nullptr)
|
|
{
|
|
}
|
|
virtual ~layer()
|
|
{
|
|
}
|
|
|
|
void set_activation(const activation_layer_ptr& activation)
|
|
{
|
|
activation_ = activation;
|
|
}
|
|
|
|
void set_nodes(const nodes& layer_nodes)
|
|
{
|
|
nodes_ = layer_nodes;
|
|
}
|
|
|
|
virtual tensor5s apply(const tensor5s& input) const final
|
|
{
|
|
const auto result = apply_impl(input);
|
|
if (activation_ == nullptr)
|
|
return result;
|
|
else
|
|
return apply_activation_layer(activation_, result);
|
|
}
|
|
|
|
virtual tensor5 get_output(const layer_ptrs& layers,
|
|
output_dict& output_cache,
|
|
std::size_t node_idx, std::size_t tensor_idx) const
|
|
{
|
|
const node_connection conn(name_, node_idx, tensor_idx);
|
|
|
|
if (!fplus::map_contains(output_cache, conn.without_tensor_idx()))
|
|
{
|
|
assertion(node_idx < nodes_.size(), "invalid node index");
|
|
output_cache[conn.without_tensor_idx()] =
|
|
nodes_[node_idx].get_output(layers, output_cache, *this);
|
|
}
|
|
|
|
const auto& outputs = fplus::get_from_map_unsafe(
|
|
output_cache, conn.without_tensor_idx());
|
|
|
|
assertion(tensor_idx < outputs.size(),
|
|
"invalid tensor index");
|
|
return outputs[tensor_idx];
|
|
}
|
|
|
|
std::string name_;
|
|
nodes nodes_;
|
|
|
|
protected:
|
|
virtual tensor5s apply_impl(const tensor5s& input) const = 0;
|
|
activation_layer_ptr activation_;
|
|
};
|
|
|
|
// Free-function wrapper around layer::get_output(): returns output tensor
// `tensor_idx` of node `node_idx` of the given layer, using (and filling)
// `output_cache`. `layer` is dereferenced without a null check.
inline tensor5 get_layer_output(
    const layer_ptrs& layers,
    output_dict& output_cache,
    const layer_ptr& layer,
    std::size_t node_idx,
    std::size_t tensor_idx)
{
    const auto& target = *layer;
    return target.get_output(layers, output_cache, node_idx, tensor_idx);
}
|
|
|
|
// Free-function wrapper around layer::apply(): runs `layer` on `inputs`
// and returns what apply() produces.
inline tensor5s apply_layer(const layer& layer, const tensor5s& inputs)
{
    tensor5s outputs = layer.apply(inputs);
    return outputs;
}
|
|
|
|
// Returns the first layer in `layers` whose name_ equals `layer_id`.
// When no layer matches, the failure is raised through
// fplus::throw_on_nothing with an error carrying
// "dangling layer reference: <layer_id>".
// Every pointer in `layers` is dereferenced without a null check.
inline layer_ptr get_layer(const layer_ptrs& layers,
    const std::string& layer_id)
{
    // Capture by reference: the predicate is only used during this call,
    // while layer_id is alive, so copying the string is unnecessary.
    const auto is_matching_layer = [&layer_id](const layer_ptr& ptr) -> bool
    {
        return ptr->name_ == layer_id;
    };
    return fplus::throw_on_nothing(
        error("dangling layer reference: " + layer_id),
        fplus::find_first_by(is_matching_layer, layers));
}
|
|
|
|
} } // namespace fdeep, namespace internal
|