From dfe03eab6e8a4d73802fa87d2513b1d495e15ffb Mon Sep 17 00:00:00 2001 From: Joeri Exelmans Date: Mon, 19 May 2025 13:18:58 +0200 Subject: [PATCH] rewrite, simply and "power-up" unification --- lib/generics/generics.js | 202 ++++++++----------------------------- lib/generics/low_level.js | 150 +++++++++++++++++++++++++++ lib/util/rbtree_wrapper.js | 10 ++ tests/generics.js | 24 ++--- 4 files changed, 211 insertions(+), 175 deletions(-) create mode 100644 lib/generics/low_level.js diff --git a/lib/generics/generics.js b/lib/generics/generics.js index eed5c41..bab3de2 100644 --- a/lib/generics/generics.js +++ b/lib/generics/generics.js @@ -1,10 +1,10 @@ import { inspect } from "node:util"; -import { eqType, getSymbol } from "../primitives/type.js"; -import { zip } from "../util/util.js"; -import { pretty, prettyT } from '../util/pretty.js'; -import { isTypeVar, TYPE_VARS } from "../primitives/typevars.js"; -import { inspectType } from "../meta/type_constructor.js"; +import { inspectType, makeTypeConstructor } from "../meta/type_constructor.js"; +import { getSymbol } from "../primitives/type.js"; +import { isTypeVar, TYPE_VARS, UNBOUND_SYMBOLS } from "../primitives/typevars.js"; import { symbolFunction } from "../structures/type_constructors.js"; +import { prettyT } from '../util/pretty.js'; +import { reduceUnif, unifyLL } from "./low_level.js"; // helper for creating generic types // for instance, the type: @@ -35,137 +35,16 @@ const _occurring = stack => type => { // Get set of type variables in type. export const occurring = _occurring([]); -// Merge 2 substitution-mappings, uni-directional. -const mergeOneWay = (m1, m2) => { - const m1copy = new Map(m1); - const m2copy = new Map(m2); - for (const [symbol1, typ1] of m1copy) { - if (m2copy.has(getSymbol(typ1))) { - // typ1 is a typeVar for which we also have a substitution - // -> fold substitutions - m1copy.set(symbol1, m2.get(getSymbol(typ1))); - m2copy.delete(getSymbol(typ1)); - return [false, m1copy, m2copy]; - } - } - return [true, m1copy, m2copy]; // stable -}; - -const checkConflict = (m1, m2) => { - for (const [symbol1, typ1] of m1) { - if (m2.has(symbol1)) { - const other = m2.get(symbol1); - if (!eqType(typ1, other)) { - throw new Error(`conflicting substitution: ${pretty(typ1)}vs. ${pretty(other)}`); - } - } - } -}; - -// Merge 2 substitution-mappings, bi-directional. -export const mergeTwoWay = (m1, m2) => { - // console.log("mergeTwoWay", {m1, m2}); - checkConflict(m1, m2); - // checkConflict(m2, m1); // <- don't think this is necessary... - // actually merge - let remaining = 2; - while (remaining > 0) { - // notice we swap m2 and m1, so the rewriting can happen both ways: - let stable; - [stable, m2, m1] = mergeOneWay(m1, m2); - remaining -= stable; - } - const result = new Map([...m1, ...m2]); - // console.log("mergeTwoWay result =", result); - return result; -}; - export class UnifyError extends Error {} export class NotAFunctionError extends Error {} -// Thanks to Hans for pointing out that this algorithm exactly like "Unification" in Prolog (hence the function name): -// https://www.dai.ed.ac.uk/groups/ssp/bookpages/quickprolog/node12.html -// -// Parameters: -// typeVars: all the type variables in both fType and aType -// fType, aType: generic types to unify -// fStack, aStack: internal use. -export const __unify = (fType, aType, fStack=[], aStack=[]) => { - // console.log("__unify", {typeVars, fType: prettyT(fType), aType: prettyT(aType), fStack, aStack}); - if (isTypeVar(fType)) { - // simplest case: formalType is a type paramater - // => substitute with actualType - // console.log(`assign ${prettyT(aType)} to ${prettyT(fType)}`); - return { - substitutions: new Map([[getSymbol(fType), aType]]), - type: aType, - }; - } - if (isTypeVar(aType)) { - // same as above, but in the other direction - // console.log(`assign ${prettyT(fType)} to ${prettyT(aType)}`); - return { - substitutions: new Map([[getSymbol(aType), fType]]), - type: fType, - }; - } - - // recursively unify - if (fType.symbol !== aType.symbol) { - throw new UnifyError(`cannot unify ${prettyT(fType)} and ${prettyT(aType)}`); - } - - const fTag = fStack.length; - const aTag = aStack.length; - - const unifications = - zip(fType.params, aType.params) - .map(([getFParam, getAParam]) => { - const fParam = getFParam(fTag); - const aParam = getAParam(aTag); - // type recursively points to an enclosing type that we've already seen - if (fStack[fParam] !== aStack[aParam]) { - // note that both are also allowed not to be mapped (undefined) - throw new UnifyError("cannot unify: types differ in their recursion"); - } - if (fStack[fParam] !== undefined) { - const type = fStack[fParam]; - return () => ({ - substitutions: new Map(), - type, - }); - } - return parent => __unify(fParam, aParam, - [...fStack, parent], - [...aStack, parent]); - }); - - const unifiedParams = unifications.map(getParam => { - return parent => getParam(parent).type; - }); - - const unifiedSubstitutions = unifications.reduce((acc, getParam) => { - const self = Symbol(); // dirty, just need something unique - const paramSubstitutions = getParam(self).substitutions; - const substitutions = mergeTwoWay(acc, paramSubstitutions); - return substitutions; - }, new Map()); - - return { - substitutions: unifiedSubstitutions, - type: { - symbol: fType.symbol, - params: unifiedParams, - [inspect.custom]: inspectType, - }, - }; -}; - export const unify = (fType, aType) => { [fType, aType] = recomputeTypeVars([fType, aType]); - const {type, substitutions} = __unify(fType, aType); - // console.log('unification complete! substitutions:', substitutions); - return recomputeTypeVars([type])[0]; + const unification = unifyLL(fType, aType); + const substitutions = reduceUnif(unification); + const uType = substitute(fType, // or aType, doesn't matter here + substitutions); + return recomputeTypeVars([uType])[0]; }; export const substitute = (type, substitutions, stack=[]) => { @@ -188,41 +67,42 @@ export const substitute = (type, substitutions, stack=[]) => { }; }; -export const assignFn = (funType, paramType) => { - const [inType, inSubst, outType, outSubst] = assignFnSubstitutions(funType, paramType); - // return recomputeTypeVars([outType])[0]; - return outType; -}; - -// same as above, but also returns the substitutions that took place -export const assignFnSubstitutions = (funType, paramType, skip=0) => { +export const assignFn = (funType, paramType, skip=0) => { + // Precondition if (getSymbol(funType) !== symbolFunction) { throw new NotAFunctionError(`${prettyT(funType)} is not a function type!`); } - const [[refunType, funS], [reparamType, paramS]] = recomputeTypeVarsSubstitutions([funType, paramType], skip); - const [inType, outType] = refunType.params.map(p => p(refunType)); - const {type: newInType, substitutions} = __unify(inType, reparamType); - const totalParamSubstitutions = mergeTwoWay(substitutions, paramS); - const newOutType = substitute(outType, substitutions); - const [[finalOutType, outsubst]] = recomputeTypeVarsSubstitutions([newOutType], skip); - const totalOutSubstitutions = mergeTwoWay(funS, outsubst); - return [newInType, totalParamSubstitutions, finalOutType, totalOutSubstitutions]; -}; + + // Step 1: Very important: Function and parameter type may have overlapping type variables, so we recompute them to make them non-overlapping: + const [funType1, paramType1] = recomputeTypeVars([funType, paramType]); + + // Step 2: Get input and output type of function + const [inType1, outType1] = funType1.params.map(p => p(funType1)); + + // Step 3: Unify parameter type with input type + const unifInType1 = unifyLL(inType1, paramType1); + + // Step 4: Substitute typevars in output type + const substInType1 = reduceUnif(unifInType1); + const reducedOutType1 = substitute(outType1, substInType1); + + // Step 5: 'Normalize' output type + const [outType] = recomputeTypeVars([reducedOutType1], skip); + return outType; +} // Ensures that no type variables overlap -export const recomputeTypeVars = types => { - return recomputeTypeVarsSubstitutions(types) - .map(([newType, _subst]) => newType); -}; - -export const recomputeTypeVarsSubstitutions = (types, skip=0) => { +export const recomputeTypeVars = (types, skip=0) => { let nextIdx = skip; return types.map(type => { - const substitutions = new Map(); - const typeVars = occurring(type); - for (const typeVar of typeVars) { - substitutions.set(typeVar, TYPE_VARS[nextIdx++]); - } - return [substitute(type, substitutions), substitutions]; - }); + const substitutions = new Map(); + const typeVars = occurring(type); + for (const typeVar of typeVars) { + const idx = nextIdx++; + if (typeVar !== UNBOUND_SYMBOLS[idx]) { + substitutions.set(typeVar, TYPE_VARS[idx]); + } + } + return substitute(type, substitutions); + }); }; diff --git a/lib/generics/low_level.js b/lib/generics/low_level.js new file mode 100644 index 0000000..76a3abd --- /dev/null +++ b/lib/generics/low_level.js @@ -0,0 +1,150 @@ +import { compareTypes } from "../compare/type.js"; +import { getHumanReadableName } from "../primitives/symbol.js"; +import { eqType, getSymbol } from "../primitives/type.js"; +import { isTypeVar } from "../primitives/typevars.js"; +import { emptySet, add, has } from "../structures/set.js"; +import { prettyT } from "../util/pretty.js"; +import { zip } from "../util/util.js"; +import { UnifyError } from "./generics.js"; + +const emptyTypeSet = emptySet(compareTypes); + +// Low-level unify +// Assumes that if types variables in typeA and typeB are overlapping, they are the same, so it may be necessary to re-compute type variables before calling this function. +export const unifyLL = (typeA, typeB, stackA = [], stackB = []) => { + if (eqType(typeA)(typeB)) { + return new Map(); + } + + if (isTypeVar(typeA) || isTypeVar(typeB)) { + const unifA = isTypeVar(typeA) + ? new Map([[getSymbol(typeA), add(emptyTypeSet)(typeB)]]) + : new Map(); + + const unifB = isTypeVar(typeB) + ? new Map([[getSymbol(typeB), add(emptyTypeSet)(typeA)]]) + : new Map(); + + return mergeUnifications(unifA, unifB); + } + + // recursively unify + if (typeA.symbol !== typeB.symbol) { + throw new UnifyError(`cannot unify ${prettyT(typeA)} and ${prettyT(typeB)}`); + } + + const tagA = stackA.length; + const tagB = stackB.length; + + const unifParams = zip(typeA.params, typeB.params) + .map(([getParamA, getParamB]) => { + const paramA = getParamA(tagA); + const paramB = getParamB(tagB); + + // type recursively points to an enclosing type that we've already seen + if (stackA[paramA] !== stackB[paramB]) { + // note that both are also allowed not to be mapped (undefined) + throw new UnifyError("cannot unify: types differ in their recursion"); + } + if (stackA[paramA] !== undefined) { + // we've already seen this type, don't endlessly recurse: + return new Map(); + } + return unifyLL(paramA, paramB, + [...stackA, tagA], + [...stackB, tagB]); + }); + + return unifParams.reduce( + (acc, cur) => mergeUnifications(acc, cur), + new Map()); +}; + +// Given two unifications, try to merge them (may throw UnifyError). +// Useful when the same type variable occurs in multiple places, to see if there are conflicts. +export const mergeUnifications = (unifA, unifB) => { + const allSymbols = new Set([...unifA.keys(), ...unifB.keys()]); + const result = new Map(); + for (const symbol of allSymbols) { + const setOfTypesA = unifA.get(symbol) || emptyTypeSet; + const setOfTypesB = unifB.get(symbol) || emptyTypeSet; + let union = setOfTypesA; + for (const typeB of setOfTypesB.keys()) { + union = addIfSafe(union, typeB); + } + result.set(symbol, union); + } + // console.log(`mergeUnifications(${prettyU(unifA)}, ${prettyU(unifB)}) = ${prettyU(result)}`); + // return transitivelyMerge(result); + return result; +}; + +const transitivelyGrow = (unif) => { + let stable = true; + const result = new Map(); + for (const [symbol, setOfTypes] of unif) { + let newSetOfTypes = setOfTypes; + for (const type of setOfTypes.keys()) { + if (isTypeVar(type)) { + const haveTypes = unif.get(getSymbol(type)); + if (haveTypes) { + for (const transitiveType of haveTypes.keys()) { + if (!has(newSetOfTypes)(transitiveType)) { + newSetOfTypes = addIfSafe(newSetOfTypes, transitiveType); + stable = false; + } + } + } + } + } + result.set(symbol, newSetOfTypes); + } + // repeat until stable + return stable ? result : transitivelyGrow(result); +}; + +const addIfSafe = (setOfTypes, typeToAdd) => { + for (const alreadyHaveType of setOfTypes.keys()) { + if (!has(setOfTypes)(typeToAdd)) { + if (isTypeVar(alreadyHaveType) && isTypeVar(typeToAdd)) { + continue; // not a problem + } + // console.log('can unify', prettyT(typeToAdd), 'and', prettyT(alreadyHaveType), '?'); + unifyLL(alreadyHaveType, typeToAdd); // may throw + } + } + return add(setOfTypes)(typeToAdd); +}; + +// Given a non-conflicting, non-empty set of types, reduce it to a single type +export const reduce = (setOfTypes) => { + for (const type of setOfTypes.keys()) { + if (!isTypeVar(type)) { + // console.log('reduce', prettyST(setOfTypes), 'to', prettyT(type)); + return type; + } + } + // console.log('reduce', prettyST(setOfTypes), 'to', prettyT(setOfTypes.keys()[0])); + return setOfTypes.keys()[0]; +}; + +// Reduce a unification to a mapping: {symbol => Type} +// this mapping can then be used for substituting the typevars (=symbols) in a type by concrete types +export const reduceUnif = (unif) => { + // console.log('b4 grown:', prettyU(unif)); + const grown = transitivelyGrow(unif); + // console.log('grown:', prettyU(grown)); + const result = new Map([...grown] + .map(([symbol, types]) => + [symbol, reduce(types)])); + // console.log('reduce', prettyU(grown), 'to', result); + return result; +}; + +// For debugging +const prettyU = unif => { + return `{${[...unif].map(([symbol, types]) => `${getHumanReadableName(symbol)} => ${prettyST(types)}`).join(', ')}}`; +}; +const prettyST = st => { + return `(${st.keys().map(prettyT).join(',')})`; +}; diff --git a/lib/util/rbtree_wrapper.js b/lib/util/rbtree_wrapper.js index a328c37..dbc3fce 100644 --- a/lib/util/rbtree_wrapper.js +++ b/lib/util/rbtree_wrapper.js @@ -18,4 +18,14 @@ export class RBTreeWrapper { static new(compareFn) { return new RBTreeWrapper(createRBTree(compareFn)) } + + // only for debugging: + keys() { + return this.tree.keys; + } + + // only for debugging: + entries() { + return this.tree.keys.map(key => [key, this.tree.get(key)]); + } } diff --git a/tests/generics.js b/tests/generics.js index 72c040d..ccb8d52 100644 --- a/tests/generics.js +++ b/tests/generics.js @@ -1,8 +1,7 @@ import assert from "node:assert"; -import { assignFn, assignFnSubstitutions, makeGeneric, unify, UnifyError } from "../lib/generics/generics.js"; +import { assignFn, makeGeneric, unify, UnifyError } from "../lib/generics/generics.js"; import { getDefaultTypeParser } from "../lib/parser/type_parser.js"; import { prettyT } from "../lib/util/pretty.js"; -import { TYPE_VARS, UNBOUND_SYMBOLS } from "../lib/primitives/typevars.js"; const mkType = getDefaultTypeParser(); @@ -67,15 +66,12 @@ assert.throws( UnifyError, ); -const [inType, inSubst, outType, outSubst] = assignFnSubstitutions( - mkType("Int -> Int"), - mkType("b"), -); - -assert.equal(prettyT(inType), "Int"); -assert.equal(prettyT(outType), "Int"); -assert.equal(inSubst.size, 1); -assert.equal(prettyT( - inSubst.get(UNBOUND_SYMBOLS[1]) // b -), "Int") -assert.equal(outSubst.size, 0); \ No newline at end of file +assert.throws( + () => { + unify( + mkType("((a -> (a -> Ordering)) -> {a})"), + mkType("([a] -> a)"), + ) + }, + UnifyError, +); \ No newline at end of file