Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
1a6e7c9
Set up Julia package and example SSP2 optimizations
lxvm Mar 13, 2026
7a0353b
Delete examples/julia/Manifest.toml
lxvm Mar 14, 2026
422bda4
fix underflow in tanh projection
lxvm Mar 14, 2026
34b8986
use evalpoly
lxvm Mar 14, 2026
fbe944e
fix projection typos/bugs
lxvm Mar 16, 2026
e4e53df
refactor projection into a loop over target points
lxvm Mar 16, 2026
46105ab
refactor interpolation into another module
lxvm Mar 16, 2026
0e66fba
update example to use a higher resolution target grid and reduce allo…
lxvm Mar 16, 2026
a5ddab8
correct the usage of cubic_adjoint deriv api
lxvm Mar 17, 2026
639ed5d
correct the comparison to finite differences because arrays are reused
lxvm Mar 17, 2026
b19bd89
avoid NaN in tanh projection when x == eta at beta = Inf
lxvm Mar 19, 2026
e9f1442
change definition of ProjectionProblem and SSP2
lxvm Mar 20, 2026
000599d
add grid as an optional argument to PaddingProblem
lxvm Mar 20, 2026
db22074
add init! to interface
lxvm Mar 20, 2026
326eeb0
fix typo
lxvm Mar 20, 2026
1562411
add support for SSP1
lxvm Mar 21, 2026
79588e5
add linear interpolation and use bc in cubic interpolation
lxvm Mar 21, 2026
75c442a
rename R_smoothing_factor to smoothing_radius like python api
lxvm Mar 21, 2026
17788f2
add pythonic API
lxvm Mar 21, 2026
f32ed89
add ChainRulesCore rrules for pythonic api
lxvm Mar 21, 2026
7502781
run checks to make sure different apis match
lxvm Mar 22, 2026
7a7ac00
document apis of padding module
lxvm Mar 25, 2026
242c131
document public api in kernels module
lxvm Mar 25, 2026
e4525d8
document public api of convolution module
lxvm Mar 25, 2026
3cfa7d9
Document public api of interpolation module
lxvm Mar 25, 2026
13ac5ea
document the public api of the projection module
lxvm Mar 25, 2026
d10ba70
document public api of ssp package
lxvm Mar 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions examples/julia/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
SSP = "e5b5d2ee-15bb-40cc-a0da-b305b842b7a8"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
201 changes: 201 additions & 0 deletions examples/julia/ssp2_example.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
using SSP: init, solve!, adjoint_solve!
using SSP: Kernel, Pad, Convolve, Project
using .Kernel: conickernel
using .Pad: FillPadding, BoundaryPadding, Inner, PaddingProblem, DefaultPaddingAlgorithm
using .Convolve: DiscreteConvolutionProblem, FFTConvolution
using .Project: ProjectionProblem, SSP1_linear, SSP1, SSP2

using Random
using CairoMakie
using CairoMakie: colormap
using NLopt


Nx = Ny = 100
grid = (
range(-1, 1, length=Nx),
range(-1, 1, length=Ny),
)
# Random.seed!(42)
# design_vars = rand(Nx, Ny)
# design_vars = [sinpi(x) * sinpi(y) for (x, y) in Iterators.product(grid...)]
design_vars = let a = 0.5, b = 0.499
# Cassini oval
[((x^2 + y^2)^2 - 2a^2 * (x^2 - y^2) + a^4 - b^4) + 0.5 for (x, y) in Iterators.product(grid...)]
end
radius = 0.1

kernel = conickernel(grid, radius)

padprob = PaddingProblem(;
data = design_vars,
boundary = BoundaryPadding(size(kernel) .- 1, size(kernel) .- 1),
# boundary = FillPadding(1.0, size(kernel) .- 1, size(kernel) .- 1),
)
padalg = DefaultPaddingAlgorithm()
padsolver = init(padprob, padalg)
padsol = solve!(padsolver)

convprob = DiscreteConvolutionProblem(;
data = padsol.value,
kernel,
)

convalg = FFTConvolution()
convsolver = init(convprob, convalg)
convsol = solve!(convsolver)

depadprob = PaddingProblem(;
data = convsol.value,
boundary = Inner(size(kernel) .- 1, size(kernel) .- 1),
)
depadalg = DefaultPaddingAlgorithm()
depadsolver = init(depadprob, depadalg)
depadsol = solve!(depadsolver)

filtered_design_vars = depadsol.value

# projection points need not be the same as design variable grid
target_grid = (
range(-1, 1, length=Nx * 1),
range(-1, 1, length=Ny * 1),
)
target_points = vec(collect(Iterators.product(target_grid...)))
projprob = ProjectionProblem(;
rho_filtered=filtered_design_vars,
grid,
target_points,
beta = Inf,
eta = 0.5,
)
# projalg = SSP1_linear()
# projalg = SSP1()
projalg = SSP2()
projsolver = init(projprob, projalg)
projsol = solve!(projsolver)

projected_design_vars = projsol.value

let
fig = Figure()
ax1 = Axis(fig[1,1]; title = "design variables", aspect=DataAspect())
h1 = heatmap!(grid..., design_vars; colormap=colormap("grays"))
Colorbar(fig[1,2], h1)

ax2 = Axis(fig[1,3]; title = "SSP2 output", aspect=DataAspect())
h2 = heatmap!(target_grid..., reshape(projected_design_vars, length.(target_grid)); colormap=colormap("grays"))
Colorbar(fig[1,4], h2)
save("design.png", fig)
end

function fom(data, grid)
return sum(abs2, data) / length(data)
end
obj = fom(projected_design_vars, grid)

function adjoint_fom(adj_fom, data, grid)
adjoint_fom!(similar(data), adj_fom, data, grid)
end
function adjoint_fom!(adj_data, adj_fom, data, grid)
adj_data .= (adj_fom / length(data)) .* 2 .* data
return adj_data
end

adj_projsol = adjoint_fom(1.0, projected_design_vars, grid)

adj_projprob = adjoint_solve!(projsolver, adj_projsol, projsol.tape)
adj_depadsol = adj_projprob.rho_filtered
adj_depadprob = adjoint_solve!(depadsolver, adj_depadsol, depadsol.tape)
adj_convsol = adj_depadprob.data
adj_convprob = adjoint_solve!(convsolver, adj_convsol, convsol.tape)
adj_padsol = adj_convprob.data
adj_padprob = adjoint_solve!(padsolver, adj_padsol, padsol.tape)
adj_design_vars = adj_padprob.data

let
fig = Figure()
ax1 = Axis(fig[1,1]; title = "SSP2 output", aspect=DataAspect())
h1 = heatmap!(ax1, target_grid..., reshape(projected_design_vars, length.(target_grid)); colormap=colormap("grays"))
Colorbar(fig[1,2], h1)

ax2 = Axis(fig[1,3]; title = "design variables gradient", aspect=DataAspect())
h2 = heatmap!(ax2, grid..., adj_design_vars; colormap=colormap("RdBu"))
Colorbar(fig[1,4], h2)
save("design_gradient.png", fig)
end

fom_withgradient = let grid=grid, padsolver=padsolver, convsolver=convsolver, depadsolver=depadsolver, projsolver=projsolver, adj_projsol=adj_projsol
function (design_vars)

padsolver.data = design_vars
padsol = solve!(padsolver)
convsolver.data = padsol.value
convsol = solve!(convsolver)
depadsolver.data = convsol.value
depadsol = solve!(depadsolver)
projsolver.rho_filtered = depadsol.value
projsol = solve!(projsolver)

_fom = fom(projsol.value, grid)
adjoint_fom!(adj_projsol, 1.0, projsol.value, grid)

adj_projprob = adjoint_solve!(projsolver, adj_projsol, projsol.tape)
adj_depadsol = adj_projprob.rho_filtered
adj_depadprob = adjoint_solve!(depadsolver, adj_depadsol, depadsol.tape)
adj_convsol = adj_depadprob.data
adj_convprob = adjoint_solve!(convsolver, adj_convsol, convsol.tape)
adj_padsol = adj_convprob.data
adj_padprob = adjoint_solve!(padsolver, adj_padsol, padsol.tape)
adj_design_vars = adj_padprob.data
return _fom, adj_design_vars
end
end

h = 1e-5
h_index = (50, 50)
# h_index = (38, 50)
perturb = zero(design_vars)
perturb[h_index...] = h
fom_ph, = fom_withgradient(design_vars + perturb)
fom_mh, = fom_withgradient(design_vars - perturb)
dfomdh_fd = (fom_ph - fom_mh) / 2h

fom_val, adj_design_vars = fom_withgradient(design_vars)
dfomdh = adj_design_vars[h_index...]
@show dfomdh_fd dfomdh

opt = NLopt.Opt(:LD_CCSAQ, length(design_vars))
evaluation_history = Float64[]
my_objective_fn = let fom_withgradient=fom_withgradient, evaluation_history=evaluation_history, design_vars=design_vars
function (x, grad)
val, adj_design = fom_withgradient(reshape(x, size(design_vars)))
if !isempty(grad)
copy!(grad, vec(adj_design))
end
push!(evaluation_history, val)
return val
end
end
NLopt.min_objective!(opt, my_objective_fn)
NLopt.maxeval!(opt, 50)
fmax, xmax, ret = NLopt.optimize(opt, vec(design_vars))

let
padsolver.data = reshape(xmax, size(design_vars))
padsol = solve!(padsolver)
convsolver.data = padsol.value
convsol = solve!(convsolver)
depadsolver.data = convsol.value
depadsol = solve!(depadsolver)
projsolver.rho_filtered = depadsol.value
projsol = solve!(projsolver)

fig = Figure()
ax1 = Axis(fig[1,1]; title = "Objective history", yscale=log10, limits = (nothing, (1e-16, 1e1)))
h1 = scatterlines!(ax1, evaluation_history)

ax2 = Axis(fig[1,2]; title = "Final SSP2 design", aspect=DataAspect())
h2 = heatmap!(target_grid..., reshape(projsol.value, length.(target_grid)); colormap=colormap("grays"))
Colorbar(fig[1,3], h2)
save("optimization.png", fig)
end
97 changes: 97 additions & 0 deletions examples/julia/ssp_comparison.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
using SSP: conic_filter, ssp1_linear, ssp1, ssp2

using Random
using CairoMakie
using CairoMakie: colormap
using NLopt
using Zygote


Nx = Ny = 100
grid = (
range(-1, 1, length=Nx),
range(-1, 1, length=Ny),
)
# Random.seed!(42)
# design_vars = rand(Nx, Ny)
# design_vars = [sinpi(x) * sinpi(y) for (x, y) in Iterators.product(grid...)]
design_vars = let a = 0.5, b = 0.499
# Cassini oval
[((x^2 + y^2)^2 - 2a^2 * (x^2 - y^2) + a^4 - b^4) + 0.5 for (x, y) in Iterators.product(grid...)]
end
radius = 0.1
beta = Inf
eta = 0.5
ssp_algs = (ssp1_linear, ssp1, ssp2)

ssp_projections = map(ssp_algs) do ssp
rho_filtered = conic_filter(design_vars, radius, grid)
rho_projected = ssp(rho_filtered, beta, eta, grid)
return rho_projected
end

let
fig = Figure(size = (1200, 400))
for (i, (ssp, rho_projected)) in enumerate(zip(ssp_algs, ssp_projections))
ax = Axis(fig[1,2i-1]; title = "$(string(nameof(ssp))) projection", aspect=DataAspect())
h = heatmap!(grid..., rho_projected; colormap=colormap("grays"))
Colorbar(fig[1,2i], h)
end
save("projection_comparison.png", fig)
end

function figure_of_merit(rho_projected)
sum(abs2, rho_projected) / length(rho_projected)
end

ssp_projection_gradients = map(ssp_algs) do ssp
design_vars_gradient = Zygote.gradient(design_vars) do design_vars
rho_filtered = conic_filter(design_vars, radius, grid)
rho_projected = ssp(rho_filtered, beta, eta, grid)
return figure_of_merit(rho_projected)
end
return design_vars_gradient[1]
end

let
fig = Figure(size = (1200, 400))
for (i, (ssp, rho_projected_gradient)) in enumerate(zip(ssp_algs, ssp_projection_gradients))
ax = Axis(fig[1,2i-1]; title = "$(string(nameof(ssp))) projection gradient", aspect=DataAspect())
h = heatmap!(grid..., rho_projected_gradient; colormap=colormap("RdBu"))
Colorbar(fig[1,2i], h)
end
save("projection_gradient_comparison.png", fig)
end

ssp_optimization_histories = map(ssp_algs) do ssp
opt = NLopt.Opt(:LD_CCSAQ, length(design_vars))
evaluation_history = Float64[]
my_objective_fn = let evaluation_history=evaluation_history, design_vars=design_vars
function (x, grad)
fom, adj_design = Zygote.withgradient(x) do x
rho_filtered = conic_filter(reshape(x, size(design_vars)), radius, grid)
rho_projected = ssp(rho_filtered, beta, eta, grid)
return figure_of_merit(rho_projected)
end
if !isempty(grad)
copy!(grad, vec(adj_design[1]))
end
push!(evaluation_history, fom)
return fom
end
end
NLopt.min_objective!(opt, my_objective_fn)
NLopt.maxeval!(opt, 50)
fmax, xmax, ret = NLopt.optimize(opt, vec(design_vars))
return evaluation_history
end

let
fig = Figure()
ax = Axis(fig[1,1]; title = "Optimization history", yscale=log10, limits = (nothing, (1e-16, 1e1)))
for (i, (ssp, evaluation_history)) in enumerate(zip(ssp_algs, ssp_optimization_histories))
scatterlines!(ax, evaluation_history; label=string(nameof(ssp)))
end
Legend(fig[1,2], ax)
save("evaluation_history_comparison.png", fig)
end
19 changes: 19 additions & 0 deletions src/julia/SSP/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name = "SSP"
uuid = "e5b5d2ee-15bb-40cc-a0da-b305b842b7a8"
version = "0.1.0"
authors = ["Lorenzo Van Munoz <lorenzo@vanmunoz.com>"]

[deps]
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FastInterpolations = "9ea80cae-fc13-4c00-8066-6eaedb12f34b"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[extensions]
SSPChainRulesCoreExt = "ChainRulesCore"

[compat]
ChainRulesCore = "1"
FFTW = "1.10.0"
FastInterpolations = "0.4.4"
33 changes: 33 additions & 0 deletions src/julia/SSP/ext/SSPChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
module SSPChainRulesCoreExt

import ChainRulesCore: rrule, NoTangent
using SSP: conic_filter, ssp1_linear, ssp1, ssp2, Project, conic_filter_withsolver, ssp_withsolver, conic_filter_rrule, ssp_rrule

function rrule(::typeof(conic_filter), data, radius, grid)
rho_filtered, solvers = conic_filter_withsolver(data, radius, grid)
function conic_filter_pullback(adj_rho_filtered)
adj_data = conic_filter_rrule(adj_rho_filtered, solvers...)
return NoTangent(), adj_data, NoTangent(), NoTangent()
end
return rho_filtered, conic_filter_pullback
end

function _ssp_rrule(alg, rho_filtered, beta, eta, grid)
rho_projected, solver = ssp_withsolver(alg, rho_filtered, beta, eta, grid)
function ssp_pullback(adj_rho_projected)
adj_rho_filtered = ssp_rrule(adj_rho_projected, solver)
return NoTangent(), adj_rho_filtered, NoTangent(), NoTangent(), NoTangent()
end
return rho_projected, ssp_pullback
end
function rrule(::typeof(ssp1_linear), rho_filtered, beta, eta, grid; kws...)
_ssp_rrule(Project.SSP1_linear(; kws...), rho_filtered, beta, eta, grid)
end
function rrule(::typeof(ssp1), rho_filtered, beta, eta, grid; kws...)
_ssp_rrule(Project.SSP1(; kws...), rho_filtered, beta, eta, grid)
end
function rrule(::typeof(ssp2), rho_filtered, beta, eta, grid; kws...)
_ssp_rrule(Project.SSP2(; kws...), rho_filtered, beta, eta, grid)
end

end
20 changes: 20 additions & 0 deletions src/julia/SSP/src/SSP.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module SSP

include("definitions.jl")
public init, init!, solve, solve!, adjoint_solve, adjoint_solve!

include("pad.jl")
public Pad
include("kernel.jl")
public Kernel
include("convolve.jl")
public Convolve
include("interpolate.jl")
public Interpolate
include("project.jl")
public Project

include("pythonic_api.jl")
public conic_filter, ssp1_linear, ssp1, ssp2

end
Loading
Loading