API Documentation
LighthouseFlux.FluxClassifier (Type)
FluxClassifier(model, optimiser, classes; params=Flux.params(model),
onehot=(label -> Flux.onehot(label, 1:length(classes))),
onecold=(label -> Flux.onecold(label, 1:length(classes))))
Return a FluxClassifier <: Lighthouse.AbstractClassifier with the given arguments:
model: a Flux model. The model must additionally support LighthouseFlux's loss and loss_and_prediction functions.
optimiser: a Flux optimiser.
classes: a Vector or Tuple of possible class values; this is the return value of Lighthouse.classes(::FluxClassifier).
params: the parameters to optimise during training; generally, a Zygote.Params value or a value that can be passed to Zygote.Params.
onehot: the function used to convert hard labels to soft labels when Lighthouse.onehot is called with this classifier. The default treats hard labels as integer indices in 1:length(classes), not as the class values themselves.
onecold: the function used to convert soft labels to hard labels when Lighthouse.onecold is called with this classifier. The default returns an integer index in 1:length(classes).
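With the default onehot and onecold arguments, hard labels are integer indices into 1:length(classes). The plain-Julia sketch below mimics that convention without depending on Flux. default_onehot and default_onecold are hypothetical stand-ins, not part of LighthouseFlux; unlike Flux.onehot (which returns a OneHotVector), default_onehot returns an ordinary Bool vector, and it rejects labels outside 1:length(classes).

```julia
# Hypothetical plain-Julia stand-ins for the FluxClassifier defaults
#   onehot  = label -> Flux.onehot(label, 1:length(classes))
#   onecold = label -> Flux.onecold(label, 1:length(classes))
# They only illustrate the label convention; they are not LighthouseFlux functions.

classes = ("cat", "dog", "bird")
label_space = 1:length(classes)  # hard labels are indices, not class values

# Hard label (an index) -> soft label (one-hot Bool vector)
function default_onehot(label)
    label in label_space || throw(ArgumentError("label $label not in $label_space"))
    return [label == l for l in label_space]
end

# Soft label (one score per class) -> hard label (index of the largest score)
default_onecold(soft) = label_space[argmax(soft)]

hard = default_onecold([0.1, 0.7, 0.2])
println(hard, " => ", classes[hard])  # index 2 corresponds to "dog"
```

To recover a class value from a hard label, index into classes with it, as in the last line.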
LighthouseFlux.loss (Function)
loss(model, batch_arguments...)
Return the scalar loss of model given batch_arguments.
This method must be implemented for all models passed to FluxClassifier.
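As a sketch of this contract, the example below defines a hypothetical ToyLinear model (a linear layer followed by a column-wise softmax) and a loss method taking (model, input_batch, soft_labels). To keep it runnable without Flux or LighthouseFlux, loss is an ordinary function with the documented signature; in a real project you would add a method to LighthouseFlux.loss for your own model type instead. Mean cross-entropy is an assumed choice of loss, not something LighthouseFlux prescribes.

```julia
# Hypothetical toy model: W * x .+ b followed by a softmax over each column
struct ToyLinear
    W::Matrix{Float64}
    b::Vector{Float64}
end

# Numerically stable softmax applied independently to each column (sample)
function softmax_cols(z)
    e = exp.(z .- maximum(z; dims=1))
    return e ./ sum(e; dims=1)
end

(m::ToyLinear)(x) = softmax_cols(m.W * x .+ m.b)

# Mirrors loss(model, batch_arguments...) with batch_arguments = (input_batch, soft_labels).
# In real code this would be a method such as LighthouseFlux.loss(m::MyModel, x, y).
function loss(m::ToyLinear, input_batch, soft_labels)
    p = m(input_batch)
    # mean cross-entropy over the samples (columns) of the batch
    return -sum(soft_labels .* log.(p)) / size(input_batch, 2)
end

m = ToyLinear(zeros(3, 2), zeros(3))     # 2 input features, 3 classes
x = rand(2, 4)                           # batch of 4 samples
y = Float64[1 0 0 1; 0 1 0 0; 0 0 1 0]   # one-hot soft labels, one column per sample
println(loss(m, x, y))                   # all-zero weights give uniform predictions, so ≈ log(3)
```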
LighthouseFlux.loss_and_prediction (Function)
loss_and_prediction(model, input_batch, other_batch_arguments...)
Return (model_loss, model_prediction) where:
model_loss is equivalent to (and defaults to) loss(model, input_batch, other_batch_arguments...).
model_prediction is a matrix where the i-th column is the soft label prediction for the i-th sample in input_batch. Thus, the number of columns should be size(input_batch)[end], while the number of rows equals the number of possible classes predicted by model. model_prediction defaults to model(input_batch).
This method must be supported for all models passed to FluxClassifier. Because it has the default return values described above, it only needs to be overloaded if those defaults do not yield the expected values for a given model type. It may additionally be overloaded to avoid redundant computation if model's loss function computes soft labels as an intermediate result.
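The sketch below illustrates both cases with a hypothetical plain-Julia ToyLinear model: a generic fallback that returns (loss(model, input_batch, other...), model(input_batch)), as the documented default does, and a specialized method that runs the forward pass once and reuses the prediction for the loss. The function names only mirror LighthouseFlux's; in a real project you would extend LighthouseFlux.loss_and_prediction rather than define your own.

```julia
# Hypothetical toy model: W * x .+ b followed by a column-wise softmax
struct ToyLinear
    W::Matrix{Float64}
    b::Vector{Float64}
end

function softmax_cols(z)
    e = exp.(z .- maximum(z; dims=1))
    return e ./ sum(e; dims=1)
end

(m::ToyLinear)(x) = softmax_cols(m.W * x .+ m.b)

# Mean cross-entropy computed from an existing prediction matrix
xent(p, soft_labels) = -sum(soft_labels .* log.(p)) / size(p, 2)

loss(m::ToyLinear, input_batch, soft_labels) = xent(m(input_batch), soft_labels)

# Generic fallback mirroring the documented defaults (runs the model twice)
loss_and_prediction(model, input_batch, other...) =
    (loss(model, input_batch, other...), model(input_batch))

# Specialized overload: one forward pass, prediction reused for the loss
function loss_and_prediction(m::ToyLinear, input_batch, soft_labels)
    p = m(input_batch)  # rows = classes, columns = samples
    return (xent(p, soft_labels), p)
end

m = ToyLinear(randn(3, 2), zeros(3))  # 2 input features, 3 classes
x = randn(2, 5)                       # batch of 5 samples
y = zeros(3, 5); y[1, :] .= 1         # every sample labeled as class 1
l, p = loss_and_prediction(m, x, y)   # dispatches to the specialized method
```

Because the specialized method is more specific than the fallback, Julia's dispatch selects it for ToyLinear models, while other model types still get the default behavior.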
LighthouseFlux.evaluate_chain_in_debug_mode (Function)
evaluate_chain_in_debug_mode(chain::Flux.Chain, input)
Evaluate chain(input), printing additional debug information at each layer.
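The idea of evaluating a chain layer by layer can be sketched in plain Julia by threading the input through a tuple of callables that stands in for the layers of a Flux.Chain. This is not LighthouseFlux's implementation: the hypothetical debug_chain below prints only each layer's index, output size, and whether every output value is finite, which is an assumption about what useful debug information looks like.

```julia
# Hypothetical sketch: apply each layer in turn, reporting on its output
function debug_chain(layers, input)
    x = input
    for (i, layer) in enumerate(layers)
        x = layer(x)
        # report the output shape and whether any value is NaN or Inf
        println("layer $i: size = ", size(x), ", all finite = ", all(isfinite, x))
    end
    return x
end

layers = (x -> 2 .* x, x -> x .+ 1, x -> log.(x))
y = debug_chain(layers, [0.5, 1.0, 2.0])
```

Printing after every layer makes it easy to see which layer first produces an unexpected shape or a non-finite value.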
Internal functions
LighthouseFlux.gather_weights_gradients (Function)
gather_weights_gradients(classifier, gradients)
Collect the weights and gradients from classifier into a Dict.
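As a rough illustration of collecting weights and gradients into a single Dict, the plain-Julia sketch below pairs each named parameter array with its gradient. The function name collect_weights_gradients and the key format ("<name>/weights", "<name>/gradients") are hypothetical; the keys and values produced by gather_weights_gradients itself may differ.

```julia
# Hypothetical sketch: parameters and gradients are both keyed by a path-like name
function collect_weights_gradients(params::AbstractDict, grads::AbstractDict)
    out = Dict{String,Any}()
    for (name, w) in params
        out["$name/weights"] = w
        # a parameter may have no gradient (e.g. it did not contribute to the loss)
        haskey(grads, name) && (out["$name/gradients"] = grads[name])
    end
    return out
end

params = Dict("dense/W" => [1.0 2.0; 3.0 4.0], "dense/b" => [0.0, 0.0])
grads = Dict("dense/W" => [0.1 0.2; 0.3 0.4])
d = collect_weights_gradients(params, grads)
```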
LighthouseFlux.fforeach_pairs (Function)
fforeach_pairs(F, x, keys=(); exclude=Functors.isleaf, cache=IdDict(),
               prune=Functors.NoKeyword(), combine=(ks, k) -> (ks..., k))
Walk the Functors.jl-compatible graph x (by calling pairs ∘ Functors.children), applying F(parent_key, child) at each step along the way. Here parent_key is formed by combining the previous parent_key with the key of the current key-value pair returned from pairs ∘ Functors.children, using combine. With the default combine, parent_key is a tuple of all keys on the path from x to child.
Example
julia> using Functors, LighthouseFlux
julia> struct Foo; x; y; end
julia> @functor Foo
julia> struct Bar; x; end
julia> @functor Bar
julia> m = Foo(Bar([1,2,3]), (4, 5, Bar(Foo(6, 7))));
julia> LighthouseFlux.fforeach_pairs((k,v) -> @show((k, v)), m)
(k, v) = ((:x,), Bar([1, 2, 3]))
(k, v) = ((:x, :x), [1, 2, 3])
(k, v) = ((:y,), (4, 5, Bar(Foo(6, 7))))
(k, v) = ((:y, 1), 4)
(k, v) = ((:y, 2), 5)
(k, v) = ((:y, 3), Bar(Foo(6, 7)))
(k, v) = ((:y, 3, :x), Foo(6, 7))
(k, v) = ((:y, 3, :x, :x), 6)
(k, v) = ((:y, 3, :x, :y), 7)
The combine argument can be used to customize how the keys are combined. For example, to build path-like strings:
julia> LighthouseFlux.fforeach_pairs((k,v) -> @show((k, v)), m, ""; combine=(ks, k) -> string(ks, "/", k))
(k, v) = ("/x", Bar([1, 2, 3]))
(k, v) = ("/x/x", [1, 2, 3])
(k, v) = ("/y", (4, 5, Bar(Foo(6, 7))))
(k, v) = ("/y/1", 4)
(k, v) = ("/y/2", 5)
(k, v) = ("/y/3", Bar(Foo(6, 7)))
(k, v) = ("/y/3/x", Foo(6, 7))
(k, v) = ("/y/3/x/x", 6)
(k, v) = ("/y/3/x/y", 7)
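To make the traversal concrete without depending on Functors.jl, the sketch below implements the same pre-order walk for nested Tuples and NamedTuples only, treating every other value (including arrays) as a leaf. walk_pairs and children_pairs are hypothetical names; unlike fforeach_pairs, this sketch has no exclude, cache, or prune options, so a substructure referenced twice is visited twice.

```julia
# Children of a node as key => child pairs; only Tuples and NamedTuples are containers here
children_pairs(x::Union{Tuple,NamedTuple}) = pairs(x)
children_pairs(x) = ()  # everything else is treated as a leaf

# Pre-order walk: call f(combined_key, child) for each child, then descend into it
function walk_pairs(f, x, keys=(); combine=(ks, k) -> (ks..., k))
    for (k, child) in children_pairs(x)
        child_keys = combine(keys, k)
        f(child_keys, child)
        walk_pairs(f, child, child_keys; combine=combine)
    end
    return nothing
end

m = (x = (x = [1, 2, 3],), y = (4, 5))
walk_pairs((k, v) -> @show((k, v)), m)
walk_pairs((k, v) -> @show((k, v)), m, ""; combine=(ks, k) -> string(ks, "/", k))
```

The visiting order matches the fforeach_pairs examples: each child is reported before any of its own children.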