API Documentation

LighthouseFlux.FluxClassifier - Type
FluxClassifier(model, optimiser, classes; params=Flux.params(model),
               onehot=(label -> Flux.onehot(label, 1:length(classes))),
               onecold=(label -> Flux.onecold(label, 1:length(classes))))

Return a FluxClassifier <: Lighthouse.AbstractClassifier with the given arguments:

  • model: a Flux model. The model must additionally support LighthouseFlux's loss and loss_and_prediction functions.

  • optimiser: a Flux optimiser

  • classes: a Vector or Tuple of possible class values; this is the return value of Lighthouse.classes(::FluxClassifier).

  • params: the parameters to optimise during training; generally, a Zygote.Params value or a value that can be passed to Zygote.Params.

  • onehot: the function used to convert hard labels to soft labels when Lighthouse.onehot is called with this classifier.

  • onecold: the function used to convert soft labels to hard labels when Lighthouse.onecold is called with this classifier.
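
For illustration, here is a minimal construction sketch. The TinyModel wrapper, layer sizes, class labels, and optimiser below are hypothetical choices, not part of the API, and the optimiser name may vary across Flux versions (ADAM in older releases, Adam in newer ones). The corresponding loss method is shown under loss below.

julia> using Flux, LighthouseFlux

julia> struct TinyModel  # hypothetical wrapper type so `loss` can be overloaded for it
           chain::Chain
       end

julia> Flux.@functor TinyModel

julia> classifier = FluxClassifier(TinyModel(Chain(Dense(4, 3))), Flux.Optimise.ADAM(),
                                   ["a", "b", "c"]);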

LighthouseFlux.loss - Function
loss(model, batch_arguments...)

Return the scalar loss of model given batch_arguments.

This method must be implemented for all models passed to FluxClassifier.
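
Continuing the hypothetical TinyModel sketch from above, a loss implementation for batches of (input, soft_label) pairs might look like:

julia> LighthouseFlux.loss(m::TinyModel, input, soft_label) = Flux.logitcrossentropy(m.chain(input), soft_label)

Note that the method is added to LighthouseFlux.loss itself (hence the qualified name), so that LighthouseFlux dispatches to it during training.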

LighthouseFlux.loss_and_prediction - Function
loss_and_prediction(model, input_batch, other_batch_arguments...)

Return (model_loss, model_prediction) where:

  • model_loss is equivalent to (and defaults to) loss(model, input_batch, other_batch_arguments...).

  • model_prediction is a matrix where the ith column is the soft label prediction for the ith sample in input_batch. Thus, the number of columns should be size(input_batch)[end], while the number of rows is equal to the number of possible classes predicted by model. model_prediction defaults to model(input_batch).

This method has the default definitions described above for all models passed to FluxClassifier, so it only needs to be overloaded if those defaults do not yield the expected values for a given model type. It may additionally be overloaded to avoid redundant computation if model's loss function computes soft labels as an intermediate result.
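
For example, if a model's loss already computes the logits, overloading loss_and_prediction lets a single forward pass serve both return values. A sketch for the hypothetical TinyModel, assuming the logitcrossentropy-based loss shown above:

julia> function LighthouseFlux.loss_and_prediction(m::TinyModel, input, soft_label)
           logits = m.chain(input)  # one forward pass, reused for both return values
           return Flux.logitcrossentropy(logits, soft_label), softmax(logits)
       end

Here softmax(logits) produces columnwise soft label predictions, matching the layout described above, while the loss value is identical to what loss would return.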


Internal functions

LighthouseFlux.fforeach_pairs - Function
fforeach_pairs(F, x, keys=(); exclude=Functors.isleaf, cache=IdDict(),
               prune=Functors.NoKeyword(), combine=(ks, k) -> (ks..., k))

Walks the Functors.jl-compatible graph x (by calling pairs ∘ Functors.children), applying F(parent_key, child) at each step along the way. Here parent_key is the key part of a key-value pair returned from pairs ∘ Functors.children, combined with the previous parent_key by combine.

Example

julia> using Functors, LighthouseFlux

julia> struct Foo; x; y; end

julia> @functor Foo

julia> struct Bar; x; end

julia> @functor Bar

julia> m = Foo(Bar([1,2,3]), (4, 5, Bar(Foo(6, 7))));

julia> LighthouseFlux.fforeach_pairs((k,v) -> @show((k, v)), m)
(k, v) = ((:x,), Bar([1, 2, 3]))
(k, v) = ((:x, :x), [1, 2, 3])
(k, v) = ((:y,), (4, 5, Bar(Foo(6, 7))))
(k, v) = ((:y, 1), 4)
(k, v) = ((:y, 2), 5)
(k, v) = ((:y, 3), Bar(Foo(6, 7)))
(k, v) = ((:y, 3, :x), Foo(6, 7))
(k, v) = ((:y, 3, :x, :x), 6)
(k, v) = ((:y, 3, :x, :y), 7)

The combine argument can be used to customize how the keys are combined. For example:

julia> LighthouseFlux.fforeach_pairs((k,v) -> @show((k, v)), m, ""; combine=(ks, k) -> string(ks, "/", k))
(k, v) = ("/x", Bar([1, 2, 3]))
(k, v) = ("/x/x", [1, 2, 3])
(k, v) = ("/y", (4, 5, Bar(Foo(6, 7))))
(k, v) = ("/y/1", 4)
(k, v) = ("/y/2", 5)
(k, v) = ("/y/3", Bar(Foo(6, 7)))
(k, v) = ("/y/3/x", Foo(6, 7))
(k, v) = ("/y/3/x/x", 6)
(k, v) = ("/y/3/x/y", 7)