-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsampler.py
More file actions
153 lines (126 loc) · 4.61 KB
/
sampler.py
File metadata and controls
153 lines (126 loc) · 4.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from abc import ABC, abstractmethod
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
List as TList,
Optional,
Tuple,
TypeVar,
Union,
)
import copy
import numpy as np
import vose
from synth.syntax.type_system import List, Type
# Type variables for the generic samplers below: ``T`` parameterises
# ``Sampler``/``ComposedSampler``; ``U`` parameterises the remaining samplers.
T = TypeVar("T")
U = TypeVar("U")
class Sampler(ABC, Generic[T]):
    """Abstract base for objects that produce values of type ``T`` on demand."""

    @abstractmethod
    def sample(self, **kwargs: Any) -> T:
        """Produce one value.

        Keyword arguments are passed through untouched; their meaning is
        defined by each concrete subclass.
        """

    def compose(self, f: Callable[[T], T]) -> "Sampler[T]":
        """Wrap this sampler so that ``f`` is applied to every value it yields.

        Args:
            f: Transformation applied to each sampled value.

        Returns:
            A ``ComposedSampler`` built from this sampler and ``f``.
        """
        wrapped: "Sampler[T]" = ComposedSampler(self, f)
        return wrapped
class ComposedSampler(Sampler[T]):
    """Sampler that post-processes the output of another sampler with a function."""

    def __init__(self, sampler: Sampler[T], f: Callable[[T], T]) -> None:
        """Store the wrapped ``sampler`` and the transformation ``f``."""
        self.sampler, self.f = sampler, f

    def sample(self, **kwargs: Any) -> T:
        """Sample from the wrapped sampler (forwarding ``kwargs``) and apply ``f``."""
        transform = self.f
        return transform(self.sampler.sample(**kwargs))
class LexiconSampler(Sampler[U]):
    """Sampler that picks elements from a fixed lexicon.

    The index of the returned element is drawn by a ``vose.Sampler`` built
    from the given weights.
    """

    def __init__(
        self,
        lexicon: TList[U],
        probabilites: Optional[Iterable[float]] = None,
        seed: Optional[int] = None,
    ) -> None:
        """Build the sampler.

        Args:
            lexicon: Candidate values; a deep copy is stored on the instance.
            probabilites: One weight per lexicon entry. A NumPy array is used
                as given. Any other iterable (including one-shot iterables
                such as generators) is materialised into a list first. When
                it is ``None`` or yields nothing, every entry gets the weight
                ``1 / len(lexicon)``.
            seed: Forwarded to ``vose.Sampler`` as its ``seed`` argument.
        """
        super().__init__()
        self.lexicon = copy.deepcopy(lexicon)

        filled_probabilities: Any
        if isinstance(probabilites, np.ndarray):
            # Arrays are passed through unchanged (their truth value is
            # ambiguous, so they must never be tested with ``bool``).
            filled_probabilities = probabilites
        else:
            # Materialise first: a generator is always truthy and
            # ``np.asarray`` would not iterate it.
            filled_probabilities = (
                list(probabilites) if probabilites is not None else []
            )
            if not filled_probabilities:
                filled_probabilities = [
                    1 / len(self.lexicon) for _ in self.lexicon
                ]

        self.sampler = vose.Sampler(np.asarray(filled_probabilities), seed=seed)

    def sample(self, **kwargs: Any) -> U:
        """Return one lexicon entry; ``kwargs`` are accepted but ignored."""
        index: int = self.sampler.sample()
        return self.lexicon[index]
class RequestSampler(Sampler[U], ABC):
    """Sampler whose output depends on a requested ``Type``."""

    def sample(self, **kwargs: Any) -> U:
        """Delegate to ``sample_for``, passing every keyword argument along.

        The requested type is expected under the ``type`` keyword, since
        ``sample_for`` takes it as a required parameter.
        """
        return self.sample_for(**kwargs)

    @abstractmethod
    def sample_for(self, type: Type, **kwargs: Any) -> U:
        """Produce one value for the requested ``type``."""

    def compose_with_type_mapper(
        self, f: Callable[[Type], Type]
    ) -> "RequestSampler[U]":
        """Return a sampler that rewrites the requested type with ``f`` first.

        Args:
            f: Function mapping the requested type to the type handed to this
                sampler.

        Returns:
            A ``TypeMappedRequestSampler`` wrapping this sampler and ``f``.
        """
        mapped_sampler: "RequestSampler[U]" = TypeMappedRequestSampler(self, f)
        return mapped_sampler
class TypeMappedRequestSampler(RequestSampler[U]):
    """Request sampler that maps the requested type before delegating."""

    def __init__(self, sampler: RequestSampler[U], f: Callable[[Type], Type]) -> None:
        """Store the wrapped ``sampler`` and the type-mapping function ``f``."""
        self.sampler, self.f = sampler, f

    def sample_for(self, type: Type, **kwargs: Any) -> U:
        """Sample from the wrapped sampler for ``f(type)``, forwarding ``kwargs``."""
        delegate = self.sampler.sample_for
        return delegate(self.f(type), **kwargs)
class ListSampler(RequestSampler[Union[TList, U]]):
    """Request sampler that builds (possibly nested) lists of sampled elements.

    For a ``List`` type, a length is drawn and that many elements are sampled
    for the list's element type. Any other type is delegated directly to the
    element sampler.
    """

    def __init__(
        self,
        element_sampler: Sampler[U],
        probabilities: Union[
            TList[float], TList[Tuple[int, float]], RequestSampler[int]
        ],
        max_depth: int = -1,
        seed: Optional[int] = None,
    ) -> None:
        """Build the sampler.

        Args:
            element_sampler: Sampler used for non-list types.
            probabilities: How list lengths are chosen. One of:
                a ``RequestSampler[int]`` queried with the list type;
                a list of ``(length, weight)`` tuples;
                a list of weights, where the weight at index ``i`` belongs to
                length ``i + 1``.
                Only the first entry is inspected to tell the two list forms
                apart.
            max_depth: Maximum accepted ``type.depth()``; a negative value
                disables the check.
            seed: Forwarded to ``vose.Sampler`` when lengths are drawn from
                weights.
        """
        super().__init__()
        self.max_depth = max_depth
        self.element_sampler = element_sampler

        # Define every attribute on both construction paths so the instance
        # always has the same shape; exactly one of ``length_sampler`` and
        # ``sampler`` ends up non-None.
        self.length_sampler: Optional[RequestSampler[int]] = None
        self.sampler: Any = None
        self._length_mapping: TList[int] = []

        if isinstance(probabilities, RequestSampler):
            self.length_sampler = probabilities
            return

        correct_prob: TList[Tuple[int, float]] = probabilities  # type: ignore
        if not isinstance(probabilities[0], tuple):
            # Bare weights: position i stands for length i + 1.
            correct_prob = [(i + 1, p) for i, p in enumerate(probabilities)]  # type: ignore
        self._length_mapping = [n for n, _ in correct_prob]
        self.sampler = vose.Sampler(
            np.array([p for _, p in correct_prob]), seed=seed
        )

    def __gen_length__(self, type: Type) -> int:
        """Draw the length of the next list generated for ``type``."""
        # Explicit None check: do not depend on the truth value of the
        # third-party sampler object.
        if self.sampler is not None:
            i: int = self.sampler.sample()
            return self._length_mapping[i]
        length_sampler = self.length_sampler
        assert length_sampler is not None  # invariant established in __init__
        return length_sampler.sample(type=type)

    def sample_for(self, type: Type, **kwargs: Any) -> Union[TList, U]:
        """Sample a value for ``type``.

        Args:
            type: Requested type. Its ``depth()`` must not exceed
                ``max_depth`` unless ``max_depth`` is negative.
            **kwargs: Forwarded to every nested ``sample`` call.

        Returns:
            A list of sampled elements when ``type`` is a ``List``, otherwise
            whatever the element sampler returns for ``type``.
        """
        assert self.max_depth < 0 or type.depth() <= self.max_depth
        if not isinstance(type, List):
            return self.element_sampler.sample(type=type, **kwargs)

        element_type = type.element_type
        # Nested lists recurse into this sampler; leaves use the element sampler.
        sampler: Sampler = (
            self if isinstance(element_type, List) else self.element_sampler
        )
        length: int = self.__gen_length__(type)
        return [sampler.sample(type=element_type, **kwargs) for _ in range(length)]
class UnionSampler(RequestSampler[Any]):
    """Request sampler that dispatches to a per-type sampler, with a fallback."""

    def __init__(
        self, samplers: Dict[Type, Sampler], fallback: Optional[Sampler] = None
    ) -> None:
        """Store the type-to-sampler mapping and the optional fallback sampler."""
        super().__init__()
        self.samplers, self.fallback = samplers, fallback

    def sample_for(self, type: Type, **kwargs: Any) -> Any:
        """Sample with the sampler registered for ``type``, else the fallback.

        An assertion fails when neither a registered sampler nor a fallback
        is available for ``type``.
        """
        chosen = self.samplers.get(type, self.fallback)
        assert (
            chosen
        ), f"UnionSampler: No sampler found for type {type}({hash(type)}) in {self}"
        return chosen.sample(type=type, **kwargs)

    def __str__(self) -> str:
        """Render the fallback and every ``type(hash):sampler`` entry."""
        head = f"UnionSampler(fallback={self.fallback}, samplers="
        entries = [
            f"{key}({hash(key)}):{value}" for key, value in self.samplers.items()
        ]
        return head + ", ".join(entries) + ")"