Merge pull request #276 from FluxML/double-trouble

Force `Float32` as type presented to Flux chains

ablaom authored Sep 29, 2024
2 parents 945016d + bc805e4 commit 19d4275

Showing 12 changed files with 84 additions and 68 deletions.
3 changes: 2 additions & 1 deletion Project.toml

@@ -33,6 +33,7 @@ julia = "1.9"
 [extras]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -42,4 +43,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

 [targets]
-test = ["CUDA", "cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"]
+test = ["CUDA", "cuDNN", "LinearAlgebra", "Logging", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"]
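(`Logging` joins the test targets because the new log-based tests in test/core.jl reference `Logging.Info`; see below.)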
15 changes: 11 additions & 4 deletions src/core.jl

@@ -276,15 +276,22 @@ input `X` and target `y` in the form required by
 by `model.batch_size`.)
 """
-function collate(model, X, y)
+function collate(model, X, y, verbosity)
     row_batches = Base.Iterators.partition(1:nrows(y), model.batch_size)
-    Xmatrix = reformat(X)
+    Xmatrix = _f32(reformat(X), verbosity)
     ymatrix = reformat(y)
     return [_get(Xmatrix, b) for b in row_batches], [_get(ymatrix, b) for b in row_batches]
 end
-function collate(model::NeuralNetworkBinaryClassifier, X, y)
+function collate(model::NeuralNetworkBinaryClassifier, X, y, verbosity)
     row_batches = Base.Iterators.partition(1:nrows(y), model.batch_size)
-    Xmatrix = reformat(X)
+    Xmatrix = _f32(reformat(X), verbosity)
     yvec = (y .== classes(y)[2])' # convert to boolean
     return [_get(Xmatrix, b) for b in row_batches], [_get(yvec, b) for b in row_batches]
 end
+
+_f32(x::AbstractArray{Float32}, verbosity) = x
+function _f32(x::AbstractArray, verbosity)
+    verbosity > 0 && @info "MLJFlux: converting input data to Float32"
+    return Float32.(x)
+end
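
As a quick illustration of the new helper's dispatch (a minimal standalone sketch; `_f32` is copied from the diff above, the sample arrays are invented):

_f32(x::AbstractArray{Float32}, verbosity) = x
function _f32(x::AbstractArray, verbosity)
    verbosity > 0 && @info "MLJFlux: converting input data to Float32"
    return Float32.(x)
end

x64 = rand(10, 3)             # Float64 matrix
x32 = _f32(x64, 1)            # logs: [ Info: MLJFlux: converting input data to Float32
@assert eltype(x32) == Float32
@assert _f32(x32, 1) === x32  # already Float32: returned as-is, no log, no copy
@assert eltype(_f32(x64, 0)) == Float32  # verbosity = 0 converts silently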

8 changes: 4 additions & 4 deletions src/encoders.jl

@@ -19,11 +19,11 @@ function ordinal_encoder_fit(X; featinds)
        feat_col = Tables.getcolumn(Tables.columns(X), i)
        feat_levels = levels(feat_col)
        # Check if feat levels is already ordinal encoded in which case we skip
-       (Set([float(i) for i in 1:length(feat_levels)]) == Set(feat_levels)) && continue
+       (Set([Float32(i) for i in 1:length(feat_levels)]) == Set(feat_levels)) && continue
        # Compute the dict using the given feature_mapper function
        mapping_matrix[i] =
-           Dict{Any, AbstractFloat}(
-               value => float(index) for (index, value) in enumerate(feat_levels)
+           Dict{eltype(feat_levels), Float32}(
+               value => Float32(index) for (index, value) in enumerate(feat_levels)
            )
    end
    return mapping_matrix
@@ -67,7 +67,7 @@ function ordinal_encoder_transform(X, mapping_matrix)
            test_levels = levels(col)
            check_unkown_levels(train_levels, test_levels)
            level2scalar = mapping_matrix[ind]
-           new_col = recode(col, level2scalar...)
+           new_col = recode(unwrap.(col), level2scalar...)
            push!(new_feats, new_col)
        else
            push!(new_feats, col)
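
To see what the `unwrap` change buys (a minimal sketch assuming only CategoricalArrays; the toy column is invented): broadcasting `unwrap` strips the CategoricalValue wrappers before recoding, so the encoded column comes back as a plain vector of Float32 scalars rather than as another categorical array.

using CategoricalArrays

col = categorical(["low", "high", "low", "mid"])  # levels: ["high", "low", "mid"]

# a Float32 mapping of the kind ordinal_encoder_fit now builds:
level2scalar = Dict{String, Float32}(v => Float32(i) for (i, v) in enumerate(levels(col)))

new_col = recode(unwrap.(col), level2scalar...)   # recode raw values, not wrappers
@assert new_col == Float32[2, 1, 2, 3]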
3 changes: 1 addition & 2 deletions src/entity_embedding.jl

@@ -36,15 +36,14 @@ julia> output = embedder(batch)
 ```
 """ # 1. Define layer struct to hold parameters
 struct EntityEmbedder{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer}
-
    embedders::A1
    modifiers::A2 # applied on the input before passing it to the embedder
    numfeats::I
 end

 # 2. Define the forward pass (i.e., calling an instance of the layer)
 (m::EntityEmbedder)(x) =
-    vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...)
+    (vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...))

 # 3. Define the constructor which initializes the parameters and returns the instance
 function EntityEmbedder(entityprops, numfeats; init = Flux.randn32)
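
For orientation, a minimal usage sketch modeled on test/entity_embedding.jl (the batch values and the second `entityprops` entry are invented; `EntityEmbedder` is internal, hence the `MLJFlux.` qualification):

using MLJFlux

# 4 features × 10 observations; rows 2 and 4 hold integer-coded categorical levels
batch = Float32.([
    0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1;
    1   2   3   4   5   6   7   8   9   10;
    0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1;
    1   1   2   2   1   1   2   2   1   1
])

entityprops = [
    (index = 2, levels = 10, newdim = 2),  # feature 2: 10 levels -> 2 embedding dims
    (index = 4, levels = 2, newdim = 1),   # feature 4: 2 levels -> 1 embedding dim
]

embedder = MLJFlux.EntityEmbedder(entityprops, 4)
output = embedder(batch)  # rows: 1 (passthrough) + 2 + 1 (passthrough) + 1 = 5
@assert size(output) == (5, 10)
@assert eltype(output) == Float32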
24 changes: 12 additions & 12 deletions src/mlj_model_interface.jl

@@ -66,35 +66,35 @@ function MLJModelInterface.fit(model::MLJFluxModel,
    X,
    y)
    # GPU and rng related variables
-   move = Mover(model.acceleration)
+   move = MLJFlux.Mover(model.acceleration)
    rng = true_rng(model)

    # Get input properties
    shape = MLJFlux.shape(model, X, y)
-   cat_inds = get_cat_inds(X)
+   cat_inds = MLJFlux.get_cat_inds(X)
    pure_continuous_input = isempty(cat_inds)

    # Decide whether to enable entity embeddings (e.g., ImageClassifier won't)
-   enable_entity_embs = is_embedding_enabled(model) && !pure_continuous_input
+   enable_entity_embs = MLJFlux.is_embedding_enabled(model) && !pure_continuous_input

    # Prepare entity embeddings inputs and encode X if entity embeddings enabled
    featnames = []
    if enable_entity_embs
-       X = convert_to_table(X)
+       X = MLJFlux.convert_to_table(X)
        featnames = Tables.schema(X).names
    end

-   # entityprops is (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i])
+   # entityprops is (index = cat_inds[i], levels = num_levels[i], newdim = newdims[i])
    # for each categorical feature
    default_embedding_dims = enable_entity_embs ? model.embedding_dims : Dict{Symbol, Real}()
    entityprops, entityemb_output_dim =
-       prepare_entityembs(X, featnames, cat_inds, default_embedding_dims)
-   X, ordinal_mappings = ordinal_encoder_fit_transform(X; featinds = cat_inds)
+       MLJFlux.prepare_entityembs(X, featnames, cat_inds, default_embedding_dims)
+   X, ordinal_mappings = MLJFlux.ordinal_encoder_fit_transform(X; featinds = cat_inds)

    ## Construct model chain
    chain =
        (!enable_entity_embs) ? construct_model_chain(model, rng, shape, move) :
-       construct_model_chain_with_entityembs(
+       MLJFlux.construct_model_chain_with_entityembs(
            model,
            rng,
            shape,
@@ -103,8 +103,8 @@
            entityemb_output_dim,
        )

-   # Format data as needed by Flux and move to GPU
-   data = move.(collate(model, X, y))
+   # Format data as needed by Flux and move to GPU
+   data = move.(MLJFlux.collate(model, X, y, verbosity))

    # Test chain works (as it may be custom)
    x = data[1][1]
@@ -143,7 +143,7 @@ function MLJModelInterface.fit(model::MLJFluxModel,
        featnames,
    )

-   # Prepare fitresult
+   # Prepare fitresult
    fitresult =
        MLJFlux.fitresult(model, Flux.cpu(chain), y, ordinal_mappings, embedding_matrices)

@@ -216,7 +216,7 @@ function MLJModelInterface.update(model::MLJFluxModel,
        chain = construct_model_chain(model, rng, shape, move)
    end
    # reset `optimiser_state`:
-   data = move.(collate(model, X, y))
+   data = move.(collate(model, X, y, verbosity))
    regularized_optimiser, optimiser_state =
        prepare_optimiser(data, model, chain)
    epochs = model.epochs
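
The user-visible effect of threading `verbosity` through to `collate` (a hedged end-to-end sketch; the data and hyperparameters are invented): fitting on Float64 input now emits a one-line `@info` about the conversion, which `verbosity = 0` suppresses.

using MLJBase, MLJFlux

X = MLJBase.table(rand(100, 3))           # Float64 columns
y = categorical(rand(["yes", "no"], 100))

mach = machine(MLJFlux.NeuralNetworkClassifier(epochs = 5), X, y)
fit!(mach, verbosity = 1)
# expected among the logs, per _f32 in src/core.jl:
# [ Info: MLJFlux: converting input data to Float32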
14 changes: 7 additions & 7 deletions test/classifier.jl

@@ -4,18 +4,17 @@ seed!(1234)
 N = 300
 Xm = MLJBase.table(randn(Float32, N, 5)); # purely numeric
 X = (; Tables.columntable(Xm)...,
-    Column1 = repeat([1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
+    Column1 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
     Column2 = categorical(repeat(['a', 'b', 'c', 'd', 'e'], Int(N / 5))),
     Column3 = categorical(repeat(["b", "c", "d", "f", "f"], Int(N / 5)), ordered = true),
-    Column4 = repeat([1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
-    Column5 = randn(N),
+    Column4 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
+    Column5 = randn(Float32, N),
     Column6 = categorical(
         repeat(["group1", "group1", "group2", "group2", "group3"], Int(N / 5)),
     ),
 )

-
-ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N)
+ycont = 2 * X.x1 - X.x3 + 0.1 * rand(Float32, N)
 m, M = minimum(ycont), maximum(ycont)
 _, a, b, _ = range(m, stop = M, length = 4) |> collect
 y = map(ycont) do η
@@ -111,7 +110,8 @@ end

 # check different resources (CPU1, CUDALibs, etc)) give about the same loss:
 reference = losses[1]
-@test all(x -> abs(x - reference) / reference < 1e-4, losses[2:end])
+println("losses for each resource: $losses")
+@test all(x -> abs(x - reference) / reference < 0.03, losses[2:end])


 # # NEURAL NETWORK BINARY CLASSIFIER

@@ -126,7 +126,7 @@
 seed!(1234)
 N = 300
 X = MLJBase.table(rand(Float32, N, 4));
-ycont = 2 * X.x1 - X.x3 + 0.1 * rand(N)
+ycont = Float32.(2 * X.x1 - X.x3 + 0.1 * rand(N))
 m, M = minimum(ycont), maximum(ycont)
 _, a, _ = range(m, stop = M, length = 3) |> collect
 y = map(ycont) do η
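
The cross-resource loss tolerance above is relaxed from 1e-4 to 3% (with the observed losses printed for diagnosis), presumably because forcing inputs to Float32 leaves CPU and GPU losses agreeing less tightly than before.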
19 changes: 12 additions & 7 deletions test/core.jl

@@ -14,23 +14,28 @@ rowvec(y::Vector) = reshape(y, 1, length(y))
 end

 @testset "collate" begin
-    # NeuralNetworRegressor:
-    Xmatrix = broadcast(x->round(x, sigdigits=2), rand(stable_rng, 10, 3))
+    Xmatrix = broadcast(x->round(x, sigdigits=2), rand(stable_rng, Float32, 10, 3))
+    Xmat_f64 = Float64.(Xmatrix)
     # convert to a column table:
     X = MLJBase.table(Xmatrix)
+    X_64 = MLJBase.table(Xmat_f64)
+
+    # NeuralNetworRegressor:
     y = rand(stable_rng, Float32, 10)
     model = MLJFlux.NeuralNetworkRegressor()
     model.batch_size= 3
-    @test MLJFlux.collate(model, X, y) ==
+    @test MLJFlux.collate(model, X, y, 1) == MLJFlux.collate(model, X_64, y, 1) ==
        ([Xmatrix'[:,1:3], Xmatrix'[:,4:6], Xmatrix'[:,7:9], Xmatrix'[:,10:10]],
         rowvec.([y[1:3], y[4:6], y[7:9], y[10:10]]))
+    @test_logs (:info,) MLJFlux.collate(model, X_64, y, 1)
+    @test_logs min_level=Logging.Info MLJFlux.collate(model, X, y, 1)
+    @test_logs min_level=Logging.Info MLJFlux.collate(model, X, y, 0)

     # NeuralNetworClassifier:
     y = categorical(['a', 'b', 'a', 'a', 'b', 'a', 'a', 'a', 'b', 'a'])
     model = MLJFlux.NeuralNetworkClassifier()
     model.batch_size = 3
-    data = MLJFlux.collate(model, X, y)
+    data = MLJFlux.collate(model, X, y, 1)

     @test data == ([Xmatrix'[:,1:3], Xmatrix'[:,4:6],
                     Xmatrix'[:,7:9], Xmatrix'[:,10:10]],
@@ -42,13 +47,13 @@ end
     y = MLJBase.table(ymatrix) # a rowaccess table
     model = MLJFlux.NeuralNetworkRegressor()
     model.batch_size= 3
-    @test MLJFlux.collate(model, X, y) ==
+    @test MLJFlux.collate(model, X, y, 1) ==
        ([Xmatrix'[:,1:3], Xmatrix'[:,4:6], Xmatrix'[:,7:9], Xmatrix'[:,10:10]],
         rowvec.([ymatrix'[:,1:3], ymatrix'[:,4:6], ymatrix'[:,7:9],
                  ymatrix'[:,10:10]]))

     y = Tables.columntable(y) # try a columnaccess table
-    @test MLJFlux.collate(model, X, y) ==
+    @test MLJFlux.collate(model, X, y, 1) ==
        ([Xmatrix'[:,1:3], Xmatrix'[:,4:6], Xmatrix'[:,7:9], Xmatrix'[:,10:10]],
         rowvec.([ymatrix'[:,1:3], ymatrix'[:,4:6],
                  ymatrix'[:,7:9], ymatrix'[:,10:10]]))
@@ -58,7 +63,7 @@ end
     y = categorical(['a', 'b', 'a', 'a', 'b', 'a', 'a', 'a', 'b', 'a'])
     model = MLJFlux.ImageClassifier(batch_size=2)

-    data = MLJFlux.collate(model, Xmatrix, y)
+    data = MLJFlux.collate(model, Xmatrix, y, 1)
     @test first.(data) == (Float32.(cat(Xmatrix[1], Xmatrix[2], dims=4)),
                            rowvec.([1 0;0 1]))
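
On the log-testing idiom above (standard Test/Logging stdlib behavior): `@test_logs (:info,) expr` asserts that `expr` emits exactly one `:info` record, here triggered by the Float64 table, while `@test_logs min_level=Logging.Info expr` with no patterns asserts that `expr` emits nothing at `Info` level or above, covering the already-Float32 input and the `verbosity = 0` case.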
4 changes: 2 additions & 2 deletions test/encoders.jl

@@ -12,8 +12,8 @@
    @test map[2] == Dict('a' => 1, 'b' => 2, 'c' => 3, 'd' => 4, 'e' => 5)
    @test map[3] == Dict("b" => 1, "c" => 2, "d" => 3)
    @test Xenc.Column1 == [1.0, 2.0, 3.0, 4.0, 5.0]
-   @test Xenc.Column2 == [1.0, 2.0, 3.0, 4.0, 5.0]
-   @test Xenc.Column3 == [1, 2, 3]
+   @test Xenc.Column2 == Float32.([1.0, 2.0, 3.0, 4.0, 5.0])
+   @test Xenc.Column3 == Float32.([1, 2, 3])
    @test Xenc.Column4 == [1.0, 2.0, 3.0, 4.0, 5.0]

    X = coerce(X, :Column1 => Multiclass)
15 changes: 8 additions & 7 deletions test/entity_embedding.jl

@@ -1,13 +1,13 @@
 """
 See more functional tests in entity_embedding_utils.jl and mlj_model_interface.jl
 """

-batch = [
+batch = Float32.([
     0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1;
-    1 2 3 4 5 6 7 8 9 10;
-    0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1
-    1 1 2 2 1 1 2 2 1 1
-]
+    1 2 3 4 5 6 7 8 9 10;
+    0.9 0.1 0.4 0.5 0.3 0.7 0.8 0.9 1.0 1.1;
+    1 1 2 2 1 1 2 2 1 1
+])


 entityprops = [
     (index = 2, levels = 10, newdim = 2),
@@ -145,7 +145,8 @@ end
    numfeats = 4
    embedder = MLJFlux.EntityEmbedder(entityprops, 4)
    output = embedder(batch)
-   @test output == batch
+   @test output ≈ batch
+   @test eltype(output) == Float32
 end