wrote new unify-function that always returns minimal set of substitutions
This commit is contained in:
parent
47786ae792
commit
68bd7cdb9f
3 changed files with 278 additions and 2 deletions
168
lib/generics/unify.js
Normal file
168
lib/generics/unify.js
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
import { eqType } from "../primitives/type.js";
|
||||
import { isTypeVar } from "../primitives/typevars.js";
|
||||
import { prettyT } from "../util/pretty.js";
|
||||
import { indent, zip } from "../util/util.js";
|
||||
import { compareStrings } from "../compare/primitives.js";
|
||||
import { getHumanReadableName } from "../primitives/symbol.js";
|
||||
import { occurring, substitute } from "./generics.js";
|
||||
|
||||
export const prettyS = (typevar, type) => {
|
||||
return `${getHumanReadableName(typevar)} ↦ ${prettyT(type)}`;
|
||||
}
|
||||
export const prettySS = (rUni) => '{'+[...rUni].map(([symbol,type]) => `${prettyS(symbol,type)}`).join(', ')+'}';
|
||||
|
||||
export class IncompabibleTypesError extends Error {
|
||||
constructor(typeA, typeB, nestedErr) {
|
||||
const msg = `\nIncompatible types: ${prettyT(typeA)} and ${prettyT(typeB)}`;
|
||||
if (nestedErr) {
|
||||
super(msg + indent(nestedErr.message, 2));
|
||||
}
|
||||
else {
|
||||
super(msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class SubstitutionCycle extends Error {
|
||||
constructor(typevar, type) {
|
||||
super(`\nSubstitution cycle: ${getHumanReadableName(typevar)} ↦ ${prettyT(type)}`);
|
||||
}
|
||||
}
|
||||
|
||||
export const subsitutionsEqual = (m1,m2) => {
|
||||
if (m1.size !== m2.size ) return false;
|
||||
for (const [key1,type1] of m1) {
|
||||
if (!eqType(m2.get(key1))(type1)) return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
// Partial ordering between types
|
||||
// - deep-equal types are equal (e.g., Int == Int)
|
||||
// - non-typevars are smaller than typevars (e.g., Int < a)
|
||||
// - typevars with smaller letters are smaller (e.g., a < b)
|
||||
// - if the symbols match, and if an ordering exists between all the parameters, then we take the ordering of the first non-equal parameter (e.g., (a,Int) < (b,Int))
|
||||
// returns: Ordering | undefined
|
||||
const partialCompareTypes = (typeA, typeB) => {
|
||||
if (isTypeVar(typeA)) {
|
||||
if (isTypeVar(typeB)) {
|
||||
return compareStrings(typeA.symbol)(typeB.symbol);
|
||||
}
|
||||
if (!isTypeVar(typeB)) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
if (isTypeVar(typeB)) {
|
||||
// console.log(typeB, 'is a typevar');
|
||||
return -1; // typeB is typevar, typeA is not
|
||||
}
|
||||
|
||||
if (typeA.symbol === typeB.symbol) {
|
||||
let result = 0;
|
||||
for (const [paramA, paramB] of zip(typeA.params, typeB.params)) {
|
||||
const paramCmp = partialCompareTypes(paramA(typeA), paramB(typeB));
|
||||
if (paramCmp === undefined) return; // no ordering
|
||||
result ||= paramCmp;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
const checkCycle = (typevar, type) => {
|
||||
if (occurring(type).has(typevar)) {
|
||||
throw new SubstitutionCycle(typevar, type);
|
||||
}
|
||||
}
|
||||
|
||||
const addReduce = (substitutions, typevar, type) => {
|
||||
// console.log('add ', prettyS(typevar, type));
|
||||
substitutions.set(typevar, type);
|
||||
}
|
||||
|
||||
const attemptReduce = (substitutions, typevar, type) => {
|
||||
// assuming 'substitutions' is already reduced as much as possible,
|
||||
// substitute all typevars in our type with the existing substitutions
|
||||
const substType = substitute(type, substitutions);
|
||||
// Check for cycles. For instance, the substitution
|
||||
// a ↦ [a]
|
||||
// is forbidden. Not sure if this is too strict, because on the other hand, we do support recursive types...
|
||||
checkCycle(typevar, substType);
|
||||
const overlappingType = substitutions.get(typevar);
|
||||
if (overlappingType) {
|
||||
const uni = unify(overlappingType, substType);
|
||||
for (const [typevar, type] of uni) {
|
||||
attemptReduce(substitutions, typevar, type);
|
||||
}
|
||||
const cmp = partialCompareTypes(substType, overlappingType);
|
||||
if (cmp === -1) {
|
||||
// our type "wins" (smaller than the existing type)
|
||||
addReduce(substitutions, typevar, substType);
|
||||
}
|
||||
else {
|
||||
// other type "wins" -> don't do anything
|
||||
}
|
||||
}
|
||||
else {
|
||||
// no overlap
|
||||
addReduce(substitutions, typevar, substType);
|
||||
}
|
||||
}
|
||||
|
||||
export const unify = (typeA, typeB) => {
|
||||
const substitutions = new Map();
|
||||
try {
|
||||
if (isTypeVar(typeA)) {
|
||||
if (isTypeVar(typeB)) {
|
||||
const cmp = partialCompareTypes(typeA, typeB);
|
||||
if (cmp === -1) {
|
||||
// console.log(prettyT(typeA), 'is smaller than', prettyT(typeB));
|
||||
return new Map([[typeB.symbol, typeA]]);
|
||||
}
|
||||
if (cmp === 1) {
|
||||
// console.log(prettyT(typeB), 'is smaller than', prettyT(typeA));
|
||||
return new Map([[typeA.symbol, typeB]]);
|
||||
}
|
||||
return new Map(); // typevars are equal
|
||||
}
|
||||
// A is typevar, B is not
|
||||
checkCycle(typeA.symbol, typeB);
|
||||
return new Map([[typeA.symbol, typeB]]);
|
||||
}
|
||||
if (isTypeVar(typeB)) {
|
||||
// B is typevar, A is not
|
||||
checkCycle(typeB.symbol, typeA);
|
||||
return new Map([[typeB.symbol, typeA]]);
|
||||
}
|
||||
|
||||
// A and B are not typevars
|
||||
if (typeA.symbol !== typeB.symbol) {
|
||||
throw new IncompabibleTypesError(typeA, typeB);
|
||||
}
|
||||
|
||||
const unifiedParams = zip(typeA.params, typeB.params)
|
||||
.map(([getParamA, getParamB]) => {
|
||||
const paramA = getParamA(typeA);
|
||||
const paramB = getParamB(typeB);
|
||||
// console.log('request...');
|
||||
return unify(paramA, paramB);
|
||||
});
|
||||
|
||||
// merge substitutions
|
||||
unifiedParams.forEach(subst => {
|
||||
for (const [typevar, type] of subst) {
|
||||
attemptReduce(substitutions, typevar, type);
|
||||
}
|
||||
});
|
||||
}
|
||||
catch (e) {
|
||||
if (e instanceof SubstitutionCycle) {
|
||||
throw new IncompabibleTypesError(typeA, typeB, e);
|
||||
}
|
||||
if (e instanceof IncompabibleTypesError) {
|
||||
// nest errors to get a nice trace of what went wrong
|
||||
throw new IncompabibleTypesError(typeA, typeB, e);
|
||||
}
|
||||
throw e;
|
||||
}
|
||||
return substitutions;
|
||||
};
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
import assert from "node:assert";
|
||||
import { assignFn, makeGeneric, unify, UnifyError } from "../lib/generics/generics.js";
|
||||
import { assignFn, unify, UnifyError } from "../lib/generics/generics.js";
|
||||
import { getDefaultTypeParser } from "../lib/parser/type_parser.js";
|
||||
import { prettyT } from "../lib/util/pretty.js";
|
||||
|
||||
|
|
@ -12,7 +12,7 @@ assert.equal(
|
|||
prettyT(
|
||||
unify(
|
||||
mkType("(a -> Int)"),
|
||||
makeGeneric(() => mkType("[Bool] -> Int")),
|
||||
mkType("[Bool] -> Int"),
|
||||
)
|
||||
),
|
||||
// expected
|
||||
|
|
|
|||
108
tests/unify.js
Normal file
108
tests/unify.js
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
import assert from "node:assert";
|
||||
|
||||
import { getDefaultTypeParser } from "../lib/parser/type_parser.js";
|
||||
import { IncompabibleTypesError, subsitutionsEqual, unify } from "../lib/generics/unify.js";
|
||||
|
||||
|
||||
const assertSubsitutionsEqual = (m1,m2) => {
|
||||
if (!subsitutionsEqual(m1,m2)) {
|
||||
throw new Error(`substitutions differ:\n m1 = ${prettySS(m1)}\n m2 = ${prettySS(m2)}`);
|
||||
}
|
||||
};
|
||||
|
||||
const mkType = getDefaultTypeParser();
|
||||
|
||||
assertSubsitutionsEqual(
|
||||
unify(
|
||||
mkType("Int"),
|
||||
mkType("Int"),
|
||||
),
|
||||
new Map(),
|
||||
);
|
||||
|
||||
assert.throws(
|
||||
() => {
|
||||
unify(
|
||||
mkType("Int"),
|
||||
mkType("Bool")
|
||||
);
|
||||
},
|
||||
IncompabibleTypesError,
|
||||
);
|
||||
|
||||
assertSubsitutionsEqual(
|
||||
unify(
|
||||
mkType("a -> Int"),
|
||||
mkType("b -> b"),
|
||||
),
|
||||
new Map([
|
||||
[mkType("a").symbol, mkType("Int")],
|
||||
[mkType("b").symbol, mkType("Int")],
|
||||
]),
|
||||
);
|
||||
|
||||
assertSubsitutionsEqual(
|
||||
unify(
|
||||
mkType("(a -> Int)"),
|
||||
mkType("[Bool] -> Int"),
|
||||
),
|
||||
new Map([
|
||||
[mkType("a").symbol, mkType("[Bool]")],
|
||||
]),
|
||||
);
|
||||
|
||||
assertSubsitutionsEqual(
|
||||
unify(
|
||||
mkType("(a -> a) -> b"),
|
||||
mkType("(Bool -> Bool) -> c"),
|
||||
),
|
||||
new Map([
|
||||
[mkType("a").symbol, mkType("Bool")],
|
||||
[mkType("c").symbol, mkType("b")],
|
||||
]),
|
||||
);
|
||||
|
||||
assert.throws(
|
||||
() => {
|
||||
unify(
|
||||
mkType("a -> (a -> Ordering)"),
|
||||
mkType("[b] -> b"),
|
||||
);
|
||||
},
|
||||
IncompabibleTypesError,
|
||||
);
|
||||
|
||||
assertSubsitutionsEqual(
|
||||
unify(
|
||||
mkType("[a] -> (Int -> a)"),
|
||||
mkType("b -> c"),
|
||||
),
|
||||
new Map([
|
||||
[mkType("b").symbol, mkType("[a]")],
|
||||
[mkType("c").symbol, mkType("Int -> a")],
|
||||
]),
|
||||
);
|
||||
|
||||
assert.throws(
|
||||
() => {
|
||||
unify(
|
||||
mkType("[a] -> (Int -> a)"),
|
||||
mkType("b -> (c -> b)"),
|
||||
);
|
||||
},
|
||||
IncompabibleTypesError,
|
||||
// String,
|
||||
);
|
||||
|
||||
assertSubsitutionsEqual(
|
||||
unify(
|
||||
mkType("a -> b -> a -> b"),
|
||||
mkType("c -> c -> d -> e"),
|
||||
),
|
||||
new Map([
|
||||
[mkType("b").symbol, mkType("a")],
|
||||
[mkType("c").symbol, mkType("a")],
|
||||
[mkType("d").symbol, mkType("a")],
|
||||
[mkType("e").symbol, mkType("a")],
|
||||
]),
|
||||
);
|
||||
Loading…
Add table
Add a link
Reference in a new issue