diff --git a/lib/generics/unify.js b/lib/generics/unify.js new file mode 100644 index 0000000..1896371 --- /dev/null +++ b/lib/generics/unify.js @@ -0,0 +1,168 @@ +import { eqType } from "../primitives/type.js"; +import { isTypeVar } from "../primitives/typevars.js"; +import { prettyT } from "../util/pretty.js"; +import { indent, zip } from "../util/util.js"; +import { compareStrings } from "../compare/primitives.js"; +import { getHumanReadableName } from "../primitives/symbol.js"; +import { occurring, substitute } from "./generics.js"; + +export const prettyS = (typevar, type) => { + return `${getHumanReadableName(typevar)} ↦ ${prettyT(type)}`; +} +export const prettySS = (rUni) => '{'+[...rUni].map(([symbol,type]) => `${prettyS(symbol,type)}`).join(', ')+'}'; + +export class IncompabibleTypesError extends Error { + constructor(typeA, typeB, nestedErr) { + const msg = `\nIncompatible types: ${prettyT(typeA)} and ${prettyT(typeB)}`; + if (nestedErr) { + super(msg + indent(nestedErr.message, 2)); + } + else { + super(msg); + } + } +} + +export class SubstitutionCycle extends Error { + constructor(typevar, type) { + super(`\nSubstitution cycle: ${getHumanReadableName(typevar)} ↦ ${prettyT(type)}`); + } +} + +export const subsitutionsEqual = (m1,m2) => { + if (m1.size !== m2.size ) return false; + for (const [key1,type1] of m1) { + if (!eqType(m2.get(key1))(type1)) return false; + } + return true; +}; + +// Partial ordering between types +// - deep-equal types are equal (e.g., Int == Int) +// - non-typevars are smaller than typevars (e.g., Int < a) +// - typevars with smaller letters are smaller (e.g., a < b) +// - if the symbols match, and if an ordering exists between all the parameters, then we take the ordering of the first non-equal parameter (e.g., (a,Int) < (b,Int)) +// returns: Ordering | undefined +const partialCompareTypes = (typeA, typeB) => { + if (isTypeVar(typeA)) { + if (isTypeVar(typeB)) { + return compareStrings(typeA.symbol)(typeB.symbol); + } + if (!isTypeVar(typeB)) { + return 1; + } + } + if (isTypeVar(typeB)) { + // console.log(typeB, 'is a typevar'); + return -1; // typeB is typevar, typeA is not + } + + if (typeA.symbol === typeB.symbol) { + let result = 0; + for (const [paramA, paramB] of zip(typeA.params, typeB.params)) { + const paramCmp = partialCompareTypes(paramA(typeA), paramB(typeB)); + if (paramCmp === undefined) return; // no ordering + result ||= paramCmp; + } + return result; + } +} + +const checkCycle = (typevar, type) => { + if (occurring(type).has(typevar)) { + throw new SubstitutionCycle(typevar, type); + } +} + +const addReduce = (substitutions, typevar, type) => { + // console.log('add ', prettyS(typevar, type)); + substitutions.set(typevar, type); +} + +const attemptReduce = (substitutions, typevar, type) => { + // assuming 'substitutions' is already reduced as much as possible, + // substitute all typevars in our type with the existing substitutions + const substType = substitute(type, substitutions); + // Check for cycles. For instance, the substitution + // a ↦ [a] + // is forbidden. Not sure if this is too strict, because on the other hand, we do support recursive types... + checkCycle(typevar, substType); + const overlappingType = substitutions.get(typevar); + if (overlappingType) { + const uni = unify(overlappingType, substType); + for (const [typevar, type] of uni) { + attemptReduce(substitutions, typevar, type); + } + const cmp = partialCompareTypes(substType, overlappingType); + if (cmp === -1) { + // our type "wins" (smaller than the existing type) + addReduce(substitutions, typevar, substType); + } + else { + // other type "wins" -> don't do anything + } + } + else { + // no overlap + addReduce(substitutions, typevar, substType); + } +} + +export const unify = (typeA, typeB) => { + const substitutions = new Map(); + try { + if (isTypeVar(typeA)) { + if (isTypeVar(typeB)) { + const cmp = partialCompareTypes(typeA, typeB); + if (cmp === -1) { + // console.log(prettyT(typeA), 'is smaller than', prettyT(typeB)); + return new Map([[typeB.symbol, typeA]]); + } + if (cmp === 1) { + // console.log(prettyT(typeB), 'is smaller than', prettyT(typeA)); + return new Map([[typeA.symbol, typeB]]); + } + return new Map(); // typevars are equal + } + // A is typevar, B is not + checkCycle(typeA.symbol, typeB); + return new Map([[typeA.symbol, typeB]]); + } + if (isTypeVar(typeB)) { + // B is typevar, A is not + checkCycle(typeB.symbol, typeA); + return new Map([[typeB.symbol, typeA]]); + } + + // A and B are not typevars + if (typeA.symbol !== typeB.symbol) { + throw new IncompabibleTypesError(typeA, typeB); + } + + const unifiedParams = zip(typeA.params, typeB.params) + .map(([getParamA, getParamB]) => { + const paramA = getParamA(typeA); + const paramB = getParamB(typeB); + // console.log('request...'); + return unify(paramA, paramB); + }); + + // merge substitutions + unifiedParams.forEach(subst => { + for (const [typevar, type] of subst) { + attemptReduce(substitutions, typevar, type); + } + }); + } + catch (e) { + if (e instanceof SubstitutionCycle) { + throw new IncompabibleTypesError(typeA, typeB, e); + } + if (e instanceof IncompabibleTypesError) { + // nest errors to get a nice trace of what went wrong + throw new IncompabibleTypesError(typeA, typeB, e); + } + throw e; + } + return substitutions; +}; diff --git a/tests/generics.js b/tests/generics.js index ccb8d52..9799dc6 100644 --- a/tests/generics.js +++ b/tests/generics.js @@ -1,5 +1,5 @@ import assert from "node:assert"; -import { assignFn, makeGeneric, unify, UnifyError } from "../lib/generics/generics.js"; +import { assignFn, unify, UnifyError } from "../lib/generics/generics.js"; import { getDefaultTypeParser } from "../lib/parser/type_parser.js"; import { prettyT } from "../lib/util/pretty.js"; @@ -12,7 +12,7 @@ assert.equal( prettyT( unify( mkType("(a -> Int)"), - makeGeneric(() => mkType("[Bool] -> Int")), + mkType("[Bool] -> Int"), ) ), // expected diff --git a/tests/unify.js b/tests/unify.js new file mode 100644 index 0000000..588de70 --- /dev/null +++ b/tests/unify.js @@ -0,0 +1,108 @@ +import assert from "node:assert"; + +import { getDefaultTypeParser } from "../lib/parser/type_parser.js"; +import { IncompabibleTypesError, subsitutionsEqual, unify } from "../lib/generics/unify.js"; + + +const assertSubsitutionsEqual = (m1,m2) => { + if (!subsitutionsEqual(m1,m2)) { + throw new Error(`substitutions differ:\n m1 = ${prettySS(m1)}\n m2 = ${prettySS(m2)}`); + } +}; + +const mkType = getDefaultTypeParser(); + +assertSubsitutionsEqual( + unify( + mkType("Int"), + mkType("Int"), + ), + new Map(), +); + +assert.throws( + () => { + unify( + mkType("Int"), + mkType("Bool") + ); + }, + IncompabibleTypesError, +); + +assertSubsitutionsEqual( + unify( + mkType("a -> Int"), + mkType("b -> b"), + ), + new Map([ + [mkType("a").symbol, mkType("Int")], + [mkType("b").symbol, mkType("Int")], + ]), +); + +assertSubsitutionsEqual( + unify( + mkType("(a -> Int)"), + mkType("[Bool] -> Int"), + ), + new Map([ + [mkType("a").symbol, mkType("[Bool]")], + ]), +); + +assertSubsitutionsEqual( + unify( + mkType("(a -> a) -> b"), + mkType("(Bool -> Bool) -> c"), + ), + new Map([ + [mkType("a").symbol, mkType("Bool")], + [mkType("c").symbol, mkType("b")], + ]), +); + +assert.throws( + () => { + unify( + mkType("a -> (a -> Ordering)"), + mkType("[b] -> b"), + ); + }, + IncompabibleTypesError, +); + +assertSubsitutionsEqual( + unify( + mkType("[a] -> (Int -> a)"), + mkType("b -> c"), + ), + new Map([ + [mkType("b").symbol, mkType("[a]")], + [mkType("c").symbol, mkType("Int -> a")], + ]), +); + +assert.throws( + () => { + unify( + mkType("[a] -> (Int -> a)"), + mkType("b -> (c -> b)"), + ); + }, + IncompabibleTypesError, + // String, +); + +assertSubsitutionsEqual( + unify( + mkType("a -> b -> a -> b"), + mkType("c -> c -> d -> e"), + ), + new Map([ + [mkType("b").symbol, mkType("a")], + [mkType("c").symbol, mkType("a")], + [mkType("d").symbol, mkType("a")], + [mkType("e").symbol, mkType("a")], + ]), +);