Category theory for legible software and corrigible world modeling

Kris Brown

Topos Institute

4/24/24

About the Topos Institute

Mission: to shape technology for public benefit by advancing sciences of connection and integration.

Three pillars of our work, from theory to practice to social impact:

  1. Collaborative modeling in science and engineering
  2. Collective intelligence, including theories of systems and interaction
  3. Research ethics


Why is this relevant to AI safety?

  1. We can use verified, interpretable models (AI or otherwise) without losing too much economic value
    • Reduce incentive to hand authority over to uninterpretable black boxes.
  2. AI can target domains with a well-specified syntax and semantics, which allows for legibility and explanation
    • Davidad’s ARIA proposal: world-models as autonomous AI output
    • Anna Leshinskaya’s talk on moral decision-making w/ combinatorial grammar
    • Yoshua Bengio’s emphasis on separating model from implementation
  3. We facilitate collaboration between diverse stakeholders and better alignment of values
    • Opens opportunities for more equitable representation of values from all stakeholders (rather than just those who’ve mastered dependent type theory)

Outline

I. Future of software engineering

II. Mini examples of the future paradigm

III. Intuition for how structural mathematics works

Stages of development for software engineering

  1. We solve problems one at a time, as they come.

This starts feeling repetitive.

  2. Abstractions act as a solution to an entire class of problems.

This feels repetitive insofar as we feel the problem classes are related.

Common abstractions are things like functions / datatypes / scripts in our language.

  3. We migrate abstractions to allow them to address changes in problem specification.

Future paradigm: mathematically-principled + automatic abstraction reuse

Understand structure of the abstraction + implementation + the relationship between them.

Three examples of abstractions

Abstraction: databases (SQL)

CREATE SCHEMA COLORGRAPH;
CREATE TABLE VERT (ID INTEGER PRIMARY KEY, COLOR VARCHAR);
CREATE TABLE EDGE (ID INTEGER PRIMARY KEY,
                   SRC INTEGER REFERENCES VERT(ID),
                   TGT INTEGER REFERENCES VERT(ID));

-- Three colored vertices joined in a cycle
INSERT INTO VERT (ID, COLOR) VALUES (1, 'B');
INSERT INTO VERT (ID, COLOR) VALUES (2, 'R');
INSERT INTO VERT (ID, COLOR) VALUES (3, 'Y');
INSERT INTO EDGE (ID, SRC, TGT) VALUES (1, 1, 2);
INSERT INTO EDGE (ID, SRC, TGT) VALUES (2, 2, 3);
INSERT INTO EDGE (ID, SRC, TGT) VALUES (3, 3, 1);

Abstraction: databases (Python)

from typing import Dict, Set, Tuple

class ColorGraph():
  def __init__(
    self, 
    edges:Set[Tuple[int, int]],
    vertices:Set[int],
    colors:Dict[int, str]):
    ...

# Alternative: declare the schema as data
SchemaColorGraph = Schema( 
  Ob=["V", "E"], 
  Hom=[("src", "E", "V"), ("tgt", "E", "V")],
  AttrType=["Color"],
  Attr=[("color", "V", "Color")]
)
ColorGraph = acset_type(SchemaColorGraph, [str])


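def triangle(graph=None):
  # Build a fresh graph unless one is supplied (avoids a shared mutable default)
  graph = ColorGraph() if graph is None else graph
  a, b, c = [add_vertex(graph, color)  
             for color in "BRY"]
  for (s, t) in [(a, b), (b, c), (c, a)]:
    add_edge(graph, s, t)
  return graph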
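
The schema-as-data idea can be tried out in plain Python. The following is a minimal sketch, assuming a hypothetical `ColorGraphTables` class (not the AlgebraicJulia API) that stores one column per attribute or foreign key and checks the foreign-key constraint when an edge is added.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ColorGraphTables:
    """Hypothetical stand-in: one column per attribute or foreign key."""
    color: List[str] = field(default_factory=list)  # attribute V -> Color
    src: List[int] = field(default_factory=list)    # foreign key E -> V
    tgt: List[int] = field(default_factory=list)    # foreign key E -> V

    def add_vertex(self, color: str) -> int:
        self.color.append(color)
        return len(self.color) - 1  # vertex IDs are row indices

    def add_edge(self, s: int, t: int) -> int:
        # Foreign-key check: both endpoints must be existing vertices
        n = len(self.color)
        if not (0 <= s < n and 0 <= t < n):
            raise ValueError("edge endpoint is not a vertex")
        self.src.append(s)
        self.tgt.append(t)
        return len(self.src) - 1


def triangle() -> ColorGraphTables:
    g = ColorGraphTables()
    a, b, c = [g.add_vertex(color) for color in "BRY"]
    for s, t in [(a, b), (b, c), (c, a)]:
        g.add_edge(s, t)
    return g


tri = triangle()
```

Here the data structure is fully determined by the list of tables and columns, which is the information a schema declaration carries.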

New problem expressed in terms of old problem.

Gluing: side-by-side comparison

# Modify `triangle` so it can overlap existing vertex IDs
def triangle(graph, v1=None, v2=None, v3=None):
  vs = [add_vertex(graph, color) if v is None else v
        for (v, color) in [(v1, "B"), (v2, "R"), (v3, "Y")]]
  for i, j in [(0, 1), (1, 2), (2, 0)]:
    add_edge(graph, vs[i], vs[j])
  return vs

# Now expressible in terms of `triangle`
def big_tri(graph):
  blue, red, _ = triangle(graph)
  _, _, yellow = triangle(graph, blue)
  triangle(graph, None, red, yellow)

# Alternative: build three separate triangles, then merge vertices
def replace(graph, fk:str, find:int, repl:int):
  ...  # TODO

def big_tri(graph):
  triangle(graph)
  triangle(graph)
  triangle(graph)
  for fk in ["src", "tgt"]:
    replace(graph, fk, 5, 2)  # red of 2nd triangle -> red of 1st
    replace(graph, fk, 7, 1)  # blue of 3rd triangle -> blue of 1st
    replace(graph, fk, 9, 6)  # yellow of 3rd triangle -> yellow of 2nd

  for v in [5, 7, 9]:
    delete_vertex(graph, v)

AlgebraicJulia

# Building blocks
T1,T2,T3 = [triangle() for _ in 1:3]
b,r,y = [add_vertex(ColorGraph(), c) for c in "BRY"]

# Relation of building blocks to each other
pattern = Relation(
  b=["t1", "t2"], r=["t2", "t3"], y=["t3", "t1"]
)

# Composition
big_tri = glue(pattern, T1=T1, T2=T2, T3=T3, 
                        b=b, r=r, y=y)
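
What the declarative gluing computes can be sketched in plain Python: take the disjoint union of the building blocks, then merge the vertices that the pattern says are shared. This is a minimal illustration, not the AlgebraicJulia implementation; the `glue` function, its dictionary-based graph encoding, and the `identify` argument are assumptions made for this example.

```python
from typing import Dict, List, Tuple

Vertex = Tuple[str, int]  # (name of building block, local vertex index)


def glue(blocks: Dict[str, dict], identify: List[Tuple[Vertex, Vertex]]) -> dict:
    """Disjoint union of the blocks, then merge each identified vertex pair."""
    parent: Dict[Vertex, Vertex] = {
        (name, i): (name, i)
        for name, g in blocks.items()
        for i in range(len(g["color"]))
    }

    def find(v: Vertex) -> Vertex:
        while parent[v] != v:
            v = parent[v]
        return v

    for a, b in identify:
        ra, rb = find(a), find(b)
        if blocks[ra[0]]["color"][ra[1]] != blocks[rb[0]]["color"][rb[1]]:
            raise ValueError("glued vertices must have the same color")
        parent[rb] = ra

    # Number the merged classes in order of first appearance
    ids: Dict[Vertex, int] = {}
    color: List[str] = []
    for v in parent:
        root = find(v)
        if root not in ids:
            ids[root] = len(color)
            color.append(blocks[root[0]]["color"][root[1]])
    edges = [
        (ids[find((name, s))], ids[find((name, t))])
        for name, g in blocks.items()
        for s, t in g["edges"]
    ]
    return {"color": color, "edges": edges}


tri = {"color": ["B", "R", "Y"], "edges": [(0, 1), (1, 2), (2, 0)]}
big_tri = glue(
    {"t1": tri, "t2": tri, "t3": tri},
    [(("t1", 0), ("t2", 0)),   # t1 and t2 share a blue corner
     (("t2", 1), ("t3", 1)),   # t2 and t3 share a red corner
     (("t3", 2), ("t1", 2))],  # t3 and t1 share a yellow corner
)
```

The `glue` function knows nothing about triangles: the specific construction lives entirely in the data passed to it.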

Abstraction: Petri nets

class PetriNet():
  def __init__(self, 
    S:Set[str], 
    T:Dict[str,
           Tuple[List[str], 
                 List[str]]]
    ): 
    ...
SchemaPetriNet = Schema(
  Ob=["S", "T", "I", "O"],
  Hom=[("is", "I", "S"), ("os", "O", "S"),
       ("it", "I", "T"), ("ot", "O", "T")]
)
PetriNet = acset_type(SchemaPetriNet)

sis = PetriNet({"S", "I"},
               {"inf": (["S","I"], ["I","I"]),
                "rec": (["I"],     ["S"])})

def ODE(p:PetriNet, rates, init, time): ...
def stochastic(p:PetriNet, rates, init, time): ...
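
One way to read a simulator off a Petri net is sketched below. It assumes mass-action kinetics and forward-Euler integration, neither of which is specified in the slides; `euler_step` and `ode` are hypothetical helpers, and the net is passed as the same transition dictionary used for `sis` above.

```python
from typing import Dict, List, Tuple

Transitions = Dict[str, Tuple[List[str], List[str]]]


def euler_step(T: Transitions, rates: Dict[str, float],
               state: Dict[str, float], dt: float) -> Dict[str, float]:
    """One forward-Euler step of mass-action dynamics read off the net."""
    new = dict(state)
    for name, (inputs, outputs) in T.items():
        flux = rates[name]
        for s in inputs:          # rate law: product of the input populations
            flux *= state[s]
        for s in inputs:          # each input arc consumes one unit
            new[s] -= flux * dt
        for s in outputs:         # each output arc produces one unit
            new[s] += flux * dt
    return new


def ode(T: Transitions, rates, init, steps: int, dt: float = 0.1):
    traj = [dict(init)]
    for _ in range(steps):
        traj.append(euler_step(T, rates, traj[-1], dt))
    return traj


sis = {"inf": (["S", "I"], ["I", "I"]), "rec": (["I"], ["S"])}
traj = ode(sis, {"inf": 0.5, "rec": 0.1}, {"S": 0.99, "I": 0.01}, steps=100)
```

Because the simulator only inspects the net's structure, the same function applies unchanged to any other Petri net.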

New problem expressed in terms of old problem.

Multiplying: side-by-side comparison

import itertools

def prepend(pre): return lambda s: pre + "_" + s

US, EU, E, W = map(prepend, ["US", "EU", "east", "west"])

def mult_EU(petri:PetriNet):
  res = PetriNet()
  for state in petri.S:
    us = add_state(res, US(state))
    eu = add_state(res, EU(state))
    add_transition(res, E(state), ([us], [eu]))
    add_transition(res, W(state), ([eu], [us]))
  
  for (T, (I, O)) in petri.T.items():
    add_transition(res, US(T), (list(map(US, I)), list(map(US, O))))
    add_transition(res, EU(T), (list(map(EU, I)), list(map(EU, O))))
  
  return res

# Alternative: a general product of two nets
def multiply(r1:PetriNet, r2:PetriNet):
  res = PetriNet()
  for (s1, s2) in itertools.product(r1.S, r2.S):
    add_state(res, f"{s1}_{s2}")

  for rx1 in r1.T:
    for s2 in r2.S:
      add_transition(res, rename_rxn(rx1, s2))

  for rx2 in r2.T:
    for s1 in r1.S:
      add_transition(res, rename_rxn(rx2, s1))

  return res

def rename_rxn(rxn, species:str): ...  # TODO

Multiplying: back to basics

Multiplying: side-by-side comparison

AlgebraicJulia

P_base = PetriNet( 
  S=1, T=2, I=2, O=2, is=1, os=1, it=[1,2], ot=[1,2] 
)

# Label P1 and P2 transitions as green or orange
h1 = homomorphism(P1, P_base, init=...)
h2 = homomorphism(P2, P_base, init=...)

# Composition of old problems
US_EU_SIS = multiply(h1, h2)
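
The multiplication of typed nets can be imitated in plain Python as follows. This is a sketch under simplifying assumptions, not AlgebraicJulia's implementation: the base net has a single species and each base transition has exactly one input and one output (as in `P_base`), every transition carries its type as a label, and the "stay" and "in" placeholder transitions are hypothetical additions that let each net take part in the other net's transition type.

```python
from itertools import product
from typing import Dict, List, Tuple

# name -> (type in the base net, input species, output species)
TypedTransitions = Dict[str, Tuple[str, List[str], List[str]]]


def typed_product(s1: List[str], t1: TypedTransitions,
                  s2: List[str], t2: TypedTransitions):
    """Pair up states freely, and transitions only when their types agree."""
    states = [f"{a}_{b}" for a, b in product(s1, s2)]
    transitions: TypedTransitions = {}
    for (n1, (ty1, in1, out1)), (n2, (ty2, in2, out2)) in product(
            t1.items(), t2.items()):
        if ty1 != ty2:
            continue  # transitions of different types are never paired
        transitions[f"{n1}_{n2}"] = (
            ty1,
            [f"{a}_{b}" for a, b in product(in1, in2)],
            [f"{a}_{b}" for a, b in product(out1, out2)],
        )
    return states, transitions


sis = {
    "inf": ("disease", ["S", "I"], ["I", "I"]),
    "rec": ("disease", ["I"], ["S"]),
    "stayS": ("travel", ["S"], ["S"]),   # placeholder: S individuals may travel
    "stayI": ("travel", ["I"], ["I"]),   # placeholder: I individuals may travel
}
regions = {
    "east": ("travel", ["US"], ["EU"]),
    "west": ("travel", ["EU"], ["US"]),
    "inUS": ("disease", ["US"], ["US"]),  # placeholder: disease acts within US
    "inEU": ("disease", ["EU"], ["EU"]),  # placeholder: disease acts within EU
}
states, transitions = typed_product(["S", "I"], sis, ["US", "EU"], regions)
```

The result has one copy of the disease dynamics per region and one copy of the travel dynamics per disease state, without any code specific to either model.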

Abstraction: Wiring diagrams

def goods_flow(steps:int, initial:List[float], 
               input:Callable):
  output = []
  trade1, trade2 = initial
  for step in range(1, steps + 1):
    # Both traders read the values from the previous step
    trade1, (trade2, out) = [
      trader1(input(step), trade2),
      trader2(trade1)
    ]
    output.append(out)
  
  return output

"""
A + B = C
±√(C) = (B, D) 

Find all integer solutions (A, D) up to `n`
"""
def solve_constraint(n:int):
  def cond(AD):
    A, D = AD
    B = -D
    C = A + B
    return C == D**2
  return filter(cond, itertools.product(
    range(n), range(n)))

New problem expressed in terms of old problem.

Substituting: side-by-side comparison

def goods_flow(steps:int, initial:List[float], 
               input:Callable):
  output = []
  trade1, trade2 = initial
  for step in range(1, steps + 1):
    trade1, (trade2, out) = [
      trader1(input(step), trade2),
      trader2(trade1)
    ]
    output.append(out)
  
  return output

\(\Rightarrow\)

def goods_flow(steps:int, initial:List[float], 
               input:Callable):
  output = []
  trade2, trade3, trade4 = initial
  for step in range(1, steps + 1):
    trade2, trade3, (trade4, out) = [ 
      trader2(trade3),
      trader3(input(step), trade2),
      trader4(trade3, trade2),
    ]
    output.append(out)

  return output

Substituting: side-by-side comparison

"""
A + B = C
±√(C) = (B, D) 

Find all integer solutions (A, D) up to `n`
"""
def solve_constraint(n:int):
  def cond(AD):
    A, D = AD
    B = -D
    C = A + B
    return C == D**2
  return filter(cond, itertools.product(
    range(n), range(n)))

\(\Rightarrow\)

"""
A + B = C
C * B = D
±√(D) = (B, E) 

Find all integer solutions (A, E) up to `n`
"""
def solve_constraint(n:int):
  def constr(AE):
    A, E = AE 
    B = -E
    C = A + B
    D = C * B 
    return D == E**2

  return filter(constr, itertools.product(
      range(n), range(n)))

Substituting: side-by-side comparison

# Semantics 1: boxes are traders
sys1 = Trader(2, 1, lambda x,y: ...)
sys2 = Trader(1, 2, lambda x: ...)
sys3 = Trader(2, 1, lambda x,y: ...)
sys4 = Trader(2, 1, lambda x,y: ...)

# Semantics 2: boxes are constraints between inputs xs and outputs ys
sys1 = Relation(2, 1, lambda xs,ys: xs[0] + xs[1] == ys[0])
sys2 = Relation(1, 2, lambda xs,ys: sqrt(xs[0]) == ys[0] 
                                    and ys[0] == -ys[1])
sys3 = sys1
sys4 = Relation(2, 1, lambda xs,ys: xs[0] * xs[1] == ys[0])


AlgebraicJulia

wd12 = WiringDiagram(["A"], ["A"])
b1 = add_box(wd12, Box("op1", ["A","D"], ["C"]))
b2 = add_box(wd12, Box("op2", ["C"], ["D","A"]))
add_wires(wd12, [
  (input_id(wd12), 1),  (b1, 2),
  (b1, 1),              (b2, 1),
  (b2, 1),              (b1, 1),
  (b2, 2),              (output_id(wd12), 1)
])


model   = oapply(wd12, [sys1,sys2])
wd234   = ocompose(wd12, 1, wd34)   # substitute wd34 for box 1
new_mod = oapply(wd234,[sys2,sys3,sys4])
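
The split between syntax (the diagram) and semantics (what each box does) can be imitated in a few lines of Python. This is a simplified sketch, not the AlgebraicJulia API: the `oapply` and `ocompose` functions below are hypothetical stand-ins that treat a diagram as a list of boxes with named wires and run every box synchronously on the previous step's values.

```python
from typing import Callable, Dict, List

Box = Dict[str, List[str]]  # {"in": [wire names], "out": [wire names]}


def oapply(diagram: List[Box], systems: List[Callable]) -> Callable:
    """Semantics: a diagram plus one function per box gives a step function."""
    def step(state: dict, external: dict) -> dict:
        env = {**state, **external}  # every box reads pre-step values
        new = dict(state)
        for box, f in zip(diagram, systems):
            outs = f(*[env[w] for w in box["in"]])
            new.update(zip(box["out"], outs))
        return new
    return step


def ocompose(outer: List[Box], i: int, inner: List[Box]) -> List[Box]:
    """Syntax: replace box `i` of `outer` by the boxes of `inner`."""
    return outer[:i] + inner + outer[i + 1:]


wd12 = [{"in": ["A", "D"], "out": ["C"]},
        {"in": ["C"], "out": ["D", "out"]}]
# A two-box refinement with the same outer interface (A, D in; C out)
wd34 = [{"in": ["A", "D"], "out": ["M"]},
        {"in": ["M", "D"], "out": ["C"]}]
wd342 = ocompose(wd12, 0, wd34)

add = lambda x, y: (x + y,)
mul = lambda x, y: (x * y,)
split = lambda x: (x, x)

model = oapply(wd12, [add, split])
new_model = oapply(wd342, [add, mul, split])

state = {"C": 0, "D": 0, "out": 0}
for _ in range(2):
    state = model(state, {"A": 1})

refined = new_model({"M": 0, "C": 0, "D": 0, "out": 0}, {"A": 1})
```

Refining a box only touches the diagram data; `oapply` and the untouched box's function are reused as they are.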

Changing from one abstraction to another

New problem expressed in terms of old problem.

Changing abstractions: side-by-side

def petri_to_graph(p:PetriNet):
  """
  Species and transitions are vertices.
  Inputs and outputs are edges.
  """
  grph = Graph()
  vs = {s: add_vertex(grph) for s in p.S}
  for (t, (i, o)) in p.T.items():
    tv = add_vertex(grph)  # vertex standing for transition t
    for e in i:
      add_edge(grph, vs[e], tv)
    for e in o:
      add_edge(grph, tv, vs[e])

  return grph

def graph_to_petri(g:Graph):
  """
  Vertices are species.
  Edges are transitions with one input and
  one output.
  """
  P = PetriNet()
  for i in g.vertices:
    add_species(P, f"v{i}") 

  for (e, (s, t)) in enumerate(g.edges):
    add_transition(P, f"e{e}",
                   ([f"v{s}"], [f"v{t}"]))

  return P


AlgebraicJulia

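# Petri net -> graph: species and transitions become vertices
F = FinFunctor(SchPetri, SchGraph, 
    S => "V", T => "V", I => "E", O => "E")

my_graph = SigmaMigration(F, my_petri)

# Graph -> Petri net: vertices become species, edges become transitions
F = FinFunctor(SchPetri, SchGraph,
  S => "V", T => "E", I => "E", O => "E")
  
my_petri = DeltaMigration(F, my_graph)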
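
The graph-to-Petri direction is simple enough to spell out in plain Python. The sketch below assumes a hypothetical dictionary encoding of an instance (its `sets` and `maps`) and a hypothetical `delta_migrate` helper; it is not the AlgebraicJulia API. Each table of the new instance is looked up through the functor, so no model-specific conversion script is needed.

```python
def delta_migrate(F_ob: dict, F_hom: dict, instance: dict) -> dict:
    """Pull an instance on the target schema back along a functor F."""
    sets = {c: list(instance["sets"][d]) for c, d in F_ob.items()}
    maps = {}
    for f, g in F_hom.items():
        if g[0] == "id":   # F sends f to an identity arrow on object g[1]
            maps[f] = {x: x for x in instance["sets"][g[1]]}
        else:              # F sends f to the named arrow g[1]
            maps[f] = dict(instance["maps"][g[1]])
    return {"sets": sets, "maps": maps}


# A directed 3-cycle as an instance of the graph schema
graph = {
    "sets": {"V": [0, 1, 2], "E": [0, 1, 2]},
    "maps": {"src": {0: 0, 1: 1, 2: 2}, "tgt": {0: 1, 1: 2, 2: 0}},
}

# Functor from the Petri net schema to the graph schema
F_ob = {"S": "V", "T": "E", "I": "E", "O": "E"}
F_hom = {"is": ("hom", "src"), "os": ("hom", "tgt"),
         "it": ("id", "E"), "ot": ("id", "E")}

petri = delta_migrate(F_ob, F_hom, graph)
```

Every edge becomes a transition with exactly one input arc and one output arc, matching the hand-written `graph_to_petri`.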

Replacing code with data

Common theme: writing code vs declaring relationships between abstractions.

| Problem | Python solution | AlgebraicJulia solution |
| Different pieces of a model need to be glued together. | Write a script which does the gluing or modifies how pieces are constructed. | Declare how overlap relates to the building blocks. (colimits) |
| Different aspects of a model need to be combined / a distinction is needed. | Write a script which creates copies of one aspect for every part of the other aspect. | Declare how the different aspects interact with each other. (limits) |
| We want to integrate systems at different levels of granularity. | Refactor the original code to incorporate the more detailed subsystem. | Separate syntax/semantics. Declare how the part relates to the whole at syntax level. (operads) |
| We make a new assumption and want to migrate old knowledge into our new understanding. | Write a script to convert old data into updated data. | Declare how the new way of seeing the world (i.e. schema) is related to the old way. (data migration) |

Why is this relevant to AI safety?

  1. We can use verified, interpretable models (AI or otherwise) without losing too much economic value
    • Reduce incentive to hand authority over to uninterpretable black boxes.
  2. AI can target domains with a well-specified syntax and semantics, which allows for legibility and explanation
    • Davidad’s ARIA proposal: world-models as autonomous AI output
    • Anna Leshinskaya’s talk on moral decision-making w/ combinatorial grammar
    • Yoshua Bengio’s emphasis on separating model from implementation
  3. We facilitate collaboration between diverse stakeholders and better alignment of values
    • Opens opportunities for more equitable representation of values from all stakeholders (rather than just those who’ve mastered dependent type theory)

Why does it work?

Informal definition: a category is a bunch of things that are related to each other

Key intuition: category theory is concerned with universal properties.

  • These change something that we once thought of as a property of an object into a kind of relationship that object has towards related objects.

Example universal property: emptiness

Consider mathematical sets which are related to each other via functions.

Definition in terms of internal properties

The empty set is the unique set which has no elements in it.

But if we look at how the empty set relates to all the other sets, we’ll eventually notice something about these relations.

Definition in terms of external relationships (universal properties)

The empty set is the unique set which has exactly one function into every other set.
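
The external characterization can be checked by brute force on a few small sets. The sketch below is only an illustration: the `functions` helper and the four-set `universe` are choices made for this example, with finite sets encoded as Python lists.

```python
from itertools import product


def functions(A: list, B: list) -> list:
    """All functions A -> B, each as a dict from elements of A to elements of B."""
    return [dict(zip(A, images)) for images in product(B, repeat=len(A))]


universe = [[], ["x"], ["x", "y"], ["x", "y", "z"]]

# Keep the sets that have exactly one function into every set in the universe
initial = [A for A in universe
           if all(len(functions(A, B)) == 1 for B in universe)]
```

Only the empty list survives the filter: it has exactly one function (the empty one) into every set, whereas any non-empty set has no function into the empty set. The test never looks inside a set to count its elements; it only counts arrows.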

Example universal property: emptiness

Consider colored graphs related to each other via vertex mappings which preserve color and edges.

Definition in terms of internal properties

The empty graph is the unique graph which has no vertices or edges in it.

But if we look at how it relates to all the other graphs, we’ll eventually notice something characteristic.

Definition in terms of external relationships (universal properties)

The empty graph is the unique graph which has exactly one graph mapping into every other graph.

Universal properties and generalizable abstractions

Category theory enforces good conceptual hygiene: one isn’t allowed to depend on “implementation details” of the things which feature in its definitions.

This underlies the ability of models built in AlgebraicJulia to be extended and generalized without requiring messy code refactoring.

Takeaways

CT is useful for the same reason interfaces are generally useful.

In particular, CT provides generalized notions of:

  • multiplication / multidimensionality
  • adding things side-by-side
  • gluing things along a common boundary
  • looking for a pattern
  • find-and-replace a pattern
  • parallel vs sequential processes
  • Mad Libs style filling in of wildcards
  • Zero and One
  • “Open” systems
  • Subsystems
  • Enforcing equations
  • Symmetry

These abstractions all fit very nicely with each other:

  • conceptually built out of basic ideas of limits, colimits, and morphisms.

We can use them to replace a large amount of our code with high level, conceptual data.

Decapodes.jl: multiphysics modeling

"""Define the multiphysics"""
Diffusion = @decapode DiffusionQuantities begin
  C::Form0{X}
  ϕ::Form1{X}
  ϕ == k(d₀{X}(C))   # Fick's first law
end
Advection = @decapode DiffusionQuantities begin
  C::Form0{X}
  (V, ϕ)::Form1{X}
  ϕ == ∧₀₁{X}(C,V)
end
Superposition = @decapode DiffusionQuantities begin
  (C, Ċ)::Form0{X}
  (ϕ, ϕ₁, ϕ₂)::Form1{X}
  ϕ == ϕ₁ + ϕ₂
  Ċ == ⋆₀⁻¹{X}(dual_d₁{X}(⋆₁{X}(ϕ)))
  ∂ₜ{Form0{X}}(C) == Ċ
end
compose_diff_adv = @relation (C, V) begin
  diffusion(C, ϕ₁)
  advection(C, ϕ₂, V)
  superposition(ϕ₁, ϕ₂, ϕ, C)
end
"""Geometry"""
mesh = loadmesh(Torus_30x10()) 
"""Assign semantics to operators"""
funcs = sym2func(mesh)
funcs[:k] = Dict(:operator => 0.05 * I(ne(mesh)), 
  :type => MatrixFunc())
funcs[:⋆₁] = Dict(:operator => ⋆(Val{1}, mesh, 
  hodge=DiagonalHodge()), :type => MatrixFunc());
funcs[:∧₀₁] = Dict(:operator => (r, c, v) -> r .= 
  ∧(Tuple{0,1}, mesh, c, v), :type => InPlaceFunc())

Decapodes.jl: simulation

Resources

Hidden slide: Why Category Theory?

Focuses on relationships between things without talking about the things themselves.

Invented in the 1940s to connect different branches of math.

A category consists of objects and morphisms (arrows).

  • We don’t need to know anything about the objects.
  • Compose \(A \rightarrow B\) and \(B \rightarrow C\) to get \(A \rightarrow C\).
  • Like a graph, but we care about paths, not edges.

CT studies certain shapes of combinations of arrows.

  • These can be local shapes, e.g. a span:    \(\huge \cdot \leftarrow \cdot \rightarrow \cdot\)

  • These can be global, e.g. an initial object: \(\huge \boxed{\cdot \rightarrow \cdot\rightarrow \cdot \rightarrow \dots}\)

Hidden slide: Applied Category Theory?

Compare to interfaces in computer science:

  • declare that some collection of things are related in a particular way without saying what they are.

interface Queue{A}

size(q:Queue) -> Int 
empty(q:Queue) -> Bool 
put(q:Queue, a:A) -> ()
get(q:Queue) -> A 

In some sense a category is just a particular interface.

interface Category{Ob,Arr}

dom(a:Arr) -> Ob 
codom(a:Arr) -> Ob 
compose(a:Arr, b:Arr) -> Arr 
id(o:Ob) -> Arr

Examples:
  • Category of sets and functions
  • Category of sets and subsets
  • Category of \(\mathbb{Z}\) and \(\leq\)
  • Category of categories and functors
  • Category of chemical reaction networks
  • Category of chemical structures
  • Category of datasets
  • Category of datatypes and programs
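
The category interface can be written down and implemented twice in Python, as sketched below. The class names `IntLeq` and `FinFunctions`, the tuple encodings of arrows, and the convention that `compose(a, b)` means "first `a`, then `b`" are assumptions made for this example. The `check_laws` function uses only the four interface operations, so it works for any implementation.

```python
from abc import ABC, abstractmethod


class Category(ABC):
    @abstractmethod
    def dom(self, a): ...
    @abstractmethod
    def codom(self, a): ...
    @abstractmethod
    def compose(self, a, b): ...   # first a, then b
    @abstractmethod
    def id(self, o): ...


class IntLeq(Category):
    """Objects are integers; (m, n) is the single arrow that exists when m <= n."""
    def dom(self, a): return a[0]
    def codom(self, a): return a[1]
    def compose(self, a, b):
        if a[1] != b[0]:
            raise ValueError("arrows do not compose")
        return (a[0], b[1])
    def id(self, o): return (o, o)


class FinFunctions(Category):
    """Objects are frozensets; an arrow is (domain, codomain, mapping dict)."""
    def dom(self, a): return a[0]
    def codom(self, a): return a[1]
    def compose(self, a, b):
        if a[1] != b[0]:
            raise ValueError("arrows do not compose")
        return (a[0], b[1], {x: b[2][a[2][x]] for x in a[0]})
    def id(self, o): return (o, o, {x: x for x in o})


def check_laws(cat: Category, f, g, h) -> bool:
    """Unit and associativity laws for a composable chain f, g, h."""
    left_unit = cat.compose(cat.id(cat.dom(f)), f) == f
    right_unit = cat.compose(f, cat.id(cat.codom(f))) == f
    assoc = (cat.compose(cat.compose(f, g), h)
             == cat.compose(f, cat.compose(g, h)))
    return left_unit and right_unit and assoc


A, B, C = frozenset({1, 2}), frozenset({"a", "b"}), frozenset({True})
f = (A, B, {1: "a", 2: "b"})
g = (B, C, {"a": True, "b": True})
sets_ok = check_laws(FinFunctions(), f, g, FinFunctions().id(C))
ints_ok = check_laws(IntLeq(), (1, 2), (2, 5), (5, 5))
```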

CT is also the study of interfaces in general. It knows which are good ones.