wrote new unify-function that always returns minimal set of reductions

This commit is contained in:
Joeri Exelmans 2025-05-23 14:04:39 +02:00
parent 47786ae792
commit 6fd4a4c0e1
3 changed files with 278 additions and 2 deletions

168
lib/generics/unify.js Normal file
View file

@ -0,0 +1,168 @@
import { eqType } from "../primitives/type.js";
import { isTypeVar } from "../primitives/typevars.js";
import { prettyT } from "../util/pretty.js";
import { indent, zip } from "../util/util.js";
import { compareStrings } from "../compare/primitives.js";
import { getHumanReadableName } from "../primitives/symbol.js";
import { occurring, substitute } from "./generics.js";
export const prettyS = (typevar, type) => {
return `${getHumanReadableName(typevar)}${prettyT(type)}`;
}
export const prettySS = (rUni) => '{'+[...rUni].map(([symbol,type]) => `${prettyS(symbol,type)}`).join(', ')+'}';
export class IncompabibleTypesError extends Error {
constructor(typeA, typeB, nestedErr) {
const msg = `\nIncompatible types: ${prettyT(typeA)} and ${prettyT(typeB)}`;
if (nestedErr) {
super(msg + indent(nestedErr.message, 2));
}
else {
super(msg);
}
}
}
export class SubstitutionCycle extends Error {
constructor(typevar, type) {
super(`\nSubstitution cycle: ${getHumanReadableName(typevar)}${prettyT(type)}`);
}
}
export const subsitutionsEqual = (m1,m2) => {
if (m1.size !== m2.size ) return false;
for (const [key1,type1] of m1) {
if (!eqType(m2.get(key1))(type1)) return false;
}
return true;
};
// Partial ordering between types
// - deep-equal types are equal (e.g., Int == Int)
// - non-typevars are smaller than typevars (e.g., Int < a)
// - typevars with smaller letters are smaller (e.g., a < b)
// - if the symbols match, and if an ordering exists between all the parameters, then we take the ordering of the first non-equal parameter (e.g., (a,Int) < (b,Int))
// returns: Ordering | undefined
const partialCompareTypes = (typeA, typeB) => {
if (isTypeVar(typeA)) {
if (isTypeVar(typeB)) {
return compareStrings(typeA.symbol)(typeB.symbol);
}
if (!isTypeVar(typeB)) {
return 1;
}
}
if (isTypeVar(typeB)) {
// console.log(typeB, 'is a typevar');
return -1; // typeB is typevar, typeA is not
}
if (typeA.symbol === typeB.symbol) {
let result = 0;
for (const [paramA, paramB] of zip(typeA.params, typeB.params)) {
const paramCmp = partialCompareTypes(paramA(typeA), paramB(typeB));
if (paramCmp === undefined) return; // no ordering
result ||= paramCmp;
}
return result;
}
}
const checkCycle = (typevar, type) => {
if (occurring(type).has(typevar)) {
throw new SubstitutionCycle(typevar, type);
}
}
const addReduce = (substitutions, typevar, type) => {
// console.log('add ', prettyS(typevar, type));
substitutions.set(typevar, type);
}
const attemptReduce = (substitutions, typevar, type) => {
// assuming 'substitutions' is already reduced as much as possible,
// substitute all typevars in our type with the existing substitutions
const substType = substitute(type, substitutions);
// Check for cycles. For instance, the substitution
// a ↦ [a]
// is forbidden. Not sure if this is too strict, because on the other hand, we do support recursive types...
checkCycle(typevar, substType);
const overlappingType = substitutions.get(typevar);
if (overlappingType) {
const uni = unify(overlappingType, substType);
for (const [typevar, type] of uni) {
attemptReduce(substitutions, typevar, type);
}
const cmp = partialCompareTypes(substType, overlappingType);
if (cmp === -1) {
// our type "wins" (smaller than the existing type)
addReduce(substitutions, typevar, substType);
}
else {
// other type "wins" -> don't do anything
}
}
else {
// no overlap
addReduce(substitutions, typevar, substType);
}
}
export const unify = (typeA, typeB) => {
const substitutions = new Map();
try {
if (isTypeVar(typeA)) {
if (isTypeVar(typeB)) {
const cmp = partialCompareTypes(typeA, typeB);
if (cmp === -1) {
// console.log(prettyT(typeA), 'is smaller than', prettyT(typeB));
return new Map([[typeB.symbol, typeA]]);
}
if (cmp === 1) {
// console.log(prettyT(typeB), 'is smaller than', prettyT(typeA));
return new Map([[typeA.symbol, typeB]]);
}
return new Map(); // typevars are equal
}
// A is typevar, B is not
checkCycle(typeA.symbol, typeB);
return new Map([[typeA.symbol, typeB]]);
}
if (isTypeVar(typeB)) {
// B is typevar, A is not
checkCycle(typeB.symbol, typeA);
return new Map([[typeB.symbol, typeA]]);
}
// A and B are not typevars
if (typeA.symbol !== typeB.symbol) {
throw new IncompabibleTypesError(typeA, typeB);
}
const unifiedParams = zip(typeA.params, typeB.params)
.map(([getParamA, getParamB]) => {
const paramA = getParamA(typeA);
const paramB = getParamB(typeB);
// console.log('request...');
return unify(paramA, paramB);
});
// merge substitutions
unifiedParams.forEach(subst => {
for (const [typevar, type] of subst) {
attemptReduce(substitutions, typevar, type);
}
});
}
catch (e) {
if (e instanceof SubstitutionCycle) {
throw new IncompabibleTypesError(typeA, typeB, e);
}
if (e instanceof IncompabibleTypesError) {
// nest errors to get a nice trace of what went wrong
throw new IncompabibleTypesError(typeA, typeB, e);
}
throw e;
}
return substitutions;
};

View file

@ -1,5 +1,5 @@
import assert from "node:assert"; import assert from "node:assert";
import { assignFn, makeGeneric, unify, UnifyError } from "../lib/generics/generics.js"; import { assignFn, unify, UnifyError } from "../lib/generics/generics.js";
import { getDefaultTypeParser } from "../lib/parser/type_parser.js"; import { getDefaultTypeParser } from "../lib/parser/type_parser.js";
import { prettyT } from "../lib/util/pretty.js"; import { prettyT } from "../lib/util/pretty.js";
@ -12,7 +12,7 @@ assert.equal(
prettyT( prettyT(
unify( unify(
mkType("(a -> Int)"), mkType("(a -> Int)"),
makeGeneric(() => mkType("[Bool] -> Int")), mkType("[Bool] -> Int"),
) )
), ),
// expected // expected

108
tests/unify.js Normal file
View file

@ -0,0 +1,108 @@
import assert from "node:assert";
import { getDefaultTypeParser } from "../lib/parser/type_parser.js";
import { IncompabibleTypesError, subsitutionsEqual, unify } from "../lib/generics/unify.js";
const assertSubsitutionsEqual = (m1,m2) => {
if (!subsitutionsEqual(m1,m2)) {
throw new Error(`substitutions differ:\n m1 = ${prettySS(m1)}\n m2 = ${prettySS(m2)}`);
}
};
const mkType = getDefaultTypeParser();
assertSubsitutionsEqual(
unify(
mkType("Int"),
mkType("Int"),
),
new Map(),
);
assert.throws(
() => {
unify(
mkType("Int"),
mkType("Bool")
);
},
IncompabibleTypesError,
);
assertSubsitutionsEqual(
unify(
mkType("a -> Int"),
mkType("b -> b"),
),
new Map([
[mkType("a").symbol, mkType("Int")],
[mkType("b").symbol, mkType("Int")],
]),
);
assertSubsitutionsEqual(
unify(
mkType("(a -> Int)"),
mkType("[Bool] -> Int"),
),
new Map([
[mkType("a").symbol, mkType("[Bool]")],
]),
);
assertSubsitutionsEqual(
unify(
mkType("(a -> a) -> b"),
mkType("(Bool -> Bool) -> c"),
),
new Map([
[mkType("a").symbol, mkType("Bool")],
[mkType("c").symbol, mkType("b")],
]),
);
assert.throws(
() => {
unify(
mkType("a -> (a -> Ordering)"),
mkType("[b] -> b"),
);
},
IncompabibleTypesError,
);
assertSubsitutionsEqual(
unify(
mkType("[a] -> (Int -> a)"),
mkType("b -> c"),
),
new Map([
[mkType("b").symbol, mkType("[a]")],
[mkType("c").symbol, mkType("Int -> a")],
]),
);
assert.throws(
() => {
unify(
mkType("[a] -> (Int -> a)"),
mkType("b -> (c -> b)"),
);
},
IncompabibleTypesError,
// String,
);
assertSubsitutionsEqual(
unify(
mkType("a -> b -> a -> b"),
mkType("c -> c -> d -> e"),
),
new Map([
[mkType("b").symbol, mkType("a")],
[mkType("c").symbol, mkType("a")],
[mkType("d").symbol, mkType("a")],
[mkType("e").symbol, mkType("a")],
]),
);