wrote new unify-function that always returns minimal set of substitutions
This commit is contained in:
parent
47786ae792
commit
68bd7cdb9f
3 changed files with 278 additions and 2 deletions
168
lib/generics/unify.js
Normal file
168
lib/generics/unify.js
Normal file
|
|
@ -0,0 +1,168 @@
|
||||||
|
import { eqType } from "../primitives/type.js";
|
||||||
|
import { isTypeVar } from "../primitives/typevars.js";
|
||||||
|
import { prettyT } from "../util/pretty.js";
|
||||||
|
import { indent, zip } from "../util/util.js";
|
||||||
|
import { compareStrings } from "../compare/primitives.js";
|
||||||
|
import { getHumanReadableName } from "../primitives/symbol.js";
|
||||||
|
import { occurring, substitute } from "./generics.js";
|
||||||
|
|
||||||
|
export const prettyS = (typevar, type) => {
|
||||||
|
return `${getHumanReadableName(typevar)} ↦ ${prettyT(type)}`;
|
||||||
|
}
|
||||||
|
export const prettySS = (rUni) => '{'+[...rUni].map(([symbol,type]) => `${prettyS(symbol,type)}`).join(', ')+'}';
|
||||||
|
|
||||||
|
export class IncompabibleTypesError extends Error {
|
||||||
|
constructor(typeA, typeB, nestedErr) {
|
||||||
|
const msg = `\nIncompatible types: ${prettyT(typeA)} and ${prettyT(typeB)}`;
|
||||||
|
if (nestedErr) {
|
||||||
|
super(msg + indent(nestedErr.message, 2));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
super(msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export class SubstitutionCycle extends Error {
|
||||||
|
constructor(typevar, type) {
|
||||||
|
super(`\nSubstitution cycle: ${getHumanReadableName(typevar)} ↦ ${prettyT(type)}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export const subsitutionsEqual = (m1,m2) => {
|
||||||
|
if (m1.size !== m2.size ) return false;
|
||||||
|
for (const [key1,type1] of m1) {
|
||||||
|
if (!eqType(m2.get(key1))(type1)) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Partial ordering between types
|
||||||
|
// - deep-equal types are equal (e.g., Int == Int)
|
||||||
|
// - non-typevars are smaller than typevars (e.g., Int < a)
|
||||||
|
// - typevars with smaller letters are smaller (e.g., a < b)
|
||||||
|
// - if the symbols match, and if an ordering exists between all the parameters, then we take the ordering of the first non-equal parameter (e.g., (a,Int) < (b,Int))
|
||||||
|
// returns: Ordering | undefined
|
||||||
|
const partialCompareTypes = (typeA, typeB) => {
|
||||||
|
if (isTypeVar(typeA)) {
|
||||||
|
if (isTypeVar(typeB)) {
|
||||||
|
return compareStrings(typeA.symbol)(typeB.symbol);
|
||||||
|
}
|
||||||
|
if (!isTypeVar(typeB)) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (isTypeVar(typeB)) {
|
||||||
|
// console.log(typeB, 'is a typevar');
|
||||||
|
return -1; // typeB is typevar, typeA is not
|
||||||
|
}
|
||||||
|
|
||||||
|
if (typeA.symbol === typeB.symbol) {
|
||||||
|
let result = 0;
|
||||||
|
for (const [paramA, paramB] of zip(typeA.params, typeB.params)) {
|
||||||
|
const paramCmp = partialCompareTypes(paramA(typeA), paramB(typeB));
|
||||||
|
if (paramCmp === undefined) return; // no ordering
|
||||||
|
result ||= paramCmp;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const checkCycle = (typevar, type) => {
|
||||||
|
if (occurring(type).has(typevar)) {
|
||||||
|
throw new SubstitutionCycle(typevar, type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const addReduce = (substitutions, typevar, type) => {
|
||||||
|
// console.log('add ', prettyS(typevar, type));
|
||||||
|
substitutions.set(typevar, type);
|
||||||
|
}
|
||||||
|
|
||||||
|
const attemptReduce = (substitutions, typevar, type) => {
|
||||||
|
// assuming 'substitutions' is already reduced as much as possible,
|
||||||
|
// substitute all typevars in our type with the existing substitutions
|
||||||
|
const substType = substitute(type, substitutions);
|
||||||
|
// Check for cycles. For instance, the substitution
|
||||||
|
// a ↦ [a]
|
||||||
|
// is forbidden. Not sure if this is too strict, because on the other hand, we do support recursive types...
|
||||||
|
checkCycle(typevar, substType);
|
||||||
|
const overlappingType = substitutions.get(typevar);
|
||||||
|
if (overlappingType) {
|
||||||
|
const uni = unify(overlappingType, substType);
|
||||||
|
for (const [typevar, type] of uni) {
|
||||||
|
attemptReduce(substitutions, typevar, type);
|
||||||
|
}
|
||||||
|
const cmp = partialCompareTypes(substType, overlappingType);
|
||||||
|
if (cmp === -1) {
|
||||||
|
// our type "wins" (smaller than the existing type)
|
||||||
|
addReduce(substitutions, typevar, substType);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// other type "wins" -> don't do anything
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// no overlap
|
||||||
|
addReduce(substitutions, typevar, substType);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export const unify = (typeA, typeB) => {
|
||||||
|
const substitutions = new Map();
|
||||||
|
try {
|
||||||
|
if (isTypeVar(typeA)) {
|
||||||
|
if (isTypeVar(typeB)) {
|
||||||
|
const cmp = partialCompareTypes(typeA, typeB);
|
||||||
|
if (cmp === -1) {
|
||||||
|
// console.log(prettyT(typeA), 'is smaller than', prettyT(typeB));
|
||||||
|
return new Map([[typeB.symbol, typeA]]);
|
||||||
|
}
|
||||||
|
if (cmp === 1) {
|
||||||
|
// console.log(prettyT(typeB), 'is smaller than', prettyT(typeA));
|
||||||
|
return new Map([[typeA.symbol, typeB]]);
|
||||||
|
}
|
||||||
|
return new Map(); // typevars are equal
|
||||||
|
}
|
||||||
|
// A is typevar, B is not
|
||||||
|
checkCycle(typeA.symbol, typeB);
|
||||||
|
return new Map([[typeA.symbol, typeB]]);
|
||||||
|
}
|
||||||
|
if (isTypeVar(typeB)) {
|
||||||
|
// B is typevar, A is not
|
||||||
|
checkCycle(typeB.symbol, typeA);
|
||||||
|
return new Map([[typeB.symbol, typeA]]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// A and B are not typevars
|
||||||
|
if (typeA.symbol !== typeB.symbol) {
|
||||||
|
throw new IncompabibleTypesError(typeA, typeB);
|
||||||
|
}
|
||||||
|
|
||||||
|
const unifiedParams = zip(typeA.params, typeB.params)
|
||||||
|
.map(([getParamA, getParamB]) => {
|
||||||
|
const paramA = getParamA(typeA);
|
||||||
|
const paramB = getParamB(typeB);
|
||||||
|
// console.log('request...');
|
||||||
|
return unify(paramA, paramB);
|
||||||
|
});
|
||||||
|
|
||||||
|
// merge substitutions
|
||||||
|
unifiedParams.forEach(subst => {
|
||||||
|
for (const [typevar, type] of subst) {
|
||||||
|
attemptReduce(substitutions, typevar, type);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
catch (e) {
|
||||||
|
if (e instanceof SubstitutionCycle) {
|
||||||
|
throw new IncompabibleTypesError(typeA, typeB, e);
|
||||||
|
}
|
||||||
|
if (e instanceof IncompabibleTypesError) {
|
||||||
|
// nest errors to get a nice trace of what went wrong
|
||||||
|
throw new IncompabibleTypesError(typeA, typeB, e);
|
||||||
|
}
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
return substitutions;
|
||||||
|
};
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import assert from "node:assert";
|
import assert from "node:assert";
|
||||||
import { assignFn, makeGeneric, unify, UnifyError } from "../lib/generics/generics.js";
|
import { assignFn, unify, UnifyError } from "../lib/generics/generics.js";
|
||||||
import { getDefaultTypeParser } from "../lib/parser/type_parser.js";
|
import { getDefaultTypeParser } from "../lib/parser/type_parser.js";
|
||||||
import { prettyT } from "../lib/util/pretty.js";
|
import { prettyT } from "../lib/util/pretty.js";
|
||||||
|
|
||||||
|
|
@ -12,7 +12,7 @@ assert.equal(
|
||||||
prettyT(
|
prettyT(
|
||||||
unify(
|
unify(
|
||||||
mkType("(a -> Int)"),
|
mkType("(a -> Int)"),
|
||||||
makeGeneric(() => mkType("[Bool] -> Int")),
|
mkType("[Bool] -> Int"),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
// expected
|
// expected
|
||||||
|
|
|
||||||
108
tests/unify.js
Normal file
108
tests/unify.js
Normal file
|
|
@ -0,0 +1,108 @@
|
||||||
|
import assert from "node:assert";
|
||||||
|
|
||||||
|
import { getDefaultTypeParser } from "../lib/parser/type_parser.js";
|
||||||
|
import { IncompabibleTypesError, subsitutionsEqual, unify } from "../lib/generics/unify.js";
|
||||||
|
|
||||||
|
|
||||||
|
const assertSubsitutionsEqual = (m1,m2) => {
|
||||||
|
if (!subsitutionsEqual(m1,m2)) {
|
||||||
|
throw new Error(`substitutions differ:\n m1 = ${prettySS(m1)}\n m2 = ${prettySS(m2)}`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const mkType = getDefaultTypeParser();
|
||||||
|
|
||||||
|
assertSubsitutionsEqual(
|
||||||
|
unify(
|
||||||
|
mkType("Int"),
|
||||||
|
mkType("Int"),
|
||||||
|
),
|
||||||
|
new Map(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert.throws(
|
||||||
|
() => {
|
||||||
|
unify(
|
||||||
|
mkType("Int"),
|
||||||
|
mkType("Bool")
|
||||||
|
);
|
||||||
|
},
|
||||||
|
IncompabibleTypesError,
|
||||||
|
);
|
||||||
|
|
||||||
|
assertSubsitutionsEqual(
|
||||||
|
unify(
|
||||||
|
mkType("a -> Int"),
|
||||||
|
mkType("b -> b"),
|
||||||
|
),
|
||||||
|
new Map([
|
||||||
|
[mkType("a").symbol, mkType("Int")],
|
||||||
|
[mkType("b").symbol, mkType("Int")],
|
||||||
|
]),
|
||||||
|
);
|
||||||
|
|
||||||
|
assertSubsitutionsEqual(
|
||||||
|
unify(
|
||||||
|
mkType("(a -> Int)"),
|
||||||
|
mkType("[Bool] -> Int"),
|
||||||
|
),
|
||||||
|
new Map([
|
||||||
|
[mkType("a").symbol, mkType("[Bool]")],
|
||||||
|
]),
|
||||||
|
);
|
||||||
|
|
||||||
|
assertSubsitutionsEqual(
|
||||||
|
unify(
|
||||||
|
mkType("(a -> a) -> b"),
|
||||||
|
mkType("(Bool -> Bool) -> c"),
|
||||||
|
),
|
||||||
|
new Map([
|
||||||
|
[mkType("a").symbol, mkType("Bool")],
|
||||||
|
[mkType("c").symbol, mkType("b")],
|
||||||
|
]),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert.throws(
|
||||||
|
() => {
|
||||||
|
unify(
|
||||||
|
mkType("a -> (a -> Ordering)"),
|
||||||
|
mkType("[b] -> b"),
|
||||||
|
);
|
||||||
|
},
|
||||||
|
IncompabibleTypesError,
|
||||||
|
);
|
||||||
|
|
||||||
|
assertSubsitutionsEqual(
|
||||||
|
unify(
|
||||||
|
mkType("[a] -> (Int -> a)"),
|
||||||
|
mkType("b -> c"),
|
||||||
|
),
|
||||||
|
new Map([
|
||||||
|
[mkType("b").symbol, mkType("[a]")],
|
||||||
|
[mkType("c").symbol, mkType("Int -> a")],
|
||||||
|
]),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert.throws(
|
||||||
|
() => {
|
||||||
|
unify(
|
||||||
|
mkType("[a] -> (Int -> a)"),
|
||||||
|
mkType("b -> (c -> b)"),
|
||||||
|
);
|
||||||
|
},
|
||||||
|
IncompabibleTypesError,
|
||||||
|
// String,
|
||||||
|
);
|
||||||
|
|
||||||
|
assertSubsitutionsEqual(
|
||||||
|
unify(
|
||||||
|
mkType("a -> b -> a -> b"),
|
||||||
|
mkType("c -> c -> d -> e"),
|
||||||
|
),
|
||||||
|
new Map([
|
||||||
|
[mkType("b").symbol, mkType("a")],
|
||||||
|
[mkType("c").symbol, mkType("a")],
|
||||||
|
[mkType("d").symbol, mkType("a")],
|
||||||
|
[mkType("e").symbol, mkType("a")],
|
||||||
|
]),
|
||||||
|
);
|
||||||
Loading…
Add table
Add a link
Reference in a new issue