Go 泛型
泛型(Generics)是 Go 1.18 引入的重要特性,它允许编写能够处理多种类型的代码,而无需为每种类型重复编写相同的逻辑。泛型通过类型参数实现,让代码更加通用和可复用。
为什么需要泛型?
在泛型出现之前,Go 开发者面临一个困境:如果想编写一个适用于多种类型的函数,通常有以下几种方式:
问题:代码重复
假设需要编写一个求和函数,适用于整数和浮点数:
// 没有泛型时,需要为每种类型写一个函数
func SumInts(numbers []int) int {
total := 0
for _, n := range numbers {
total += n
}
return total
}
func SumFloats(numbers []float64) float64 {
var total float64
for _, n := range numbers {
total += n
}
return total
}
这两个函数逻辑完全相同,只是类型不同,这就是代码重复。
问题:interface 的局限
另一种方式是使用 interface{},但这会失去类型安全:
func Sum(numbers interface{}) interface{} {
// 需要运行时类型断言,编译器无法检查类型错误
switch v := numbers.(type) {
case []int:
// ...
case []float64:
// ...
}
}
这种方式的缺点是:
- 运行时才能发现类型错误
- 代码冗长且难以维护
- 性能有损失(需要类型断言)
泛型的解决方案
泛型让编译器在编译时就能检查类型,同时保持代码的通用性:
func Sum[T int | float64](numbers []T) T {
var total T
for _, n := range numbers {
total += n
}
return total
}
// 使用
Sum([]int{1, 2, 3}) // 返回 6
Sum([]float64{1.1, 2.2}) // 返回 3.3
基本语法
泛型函数
泛型函数在函数名后使用方括号 [] 声明类型参数:
func 函数名[类型参数](参数) 返回值 {
// 函数体
}
类型参数的声明格式:类型参数名 约束
// T 是类型参数,int | float64 是约束
func Print[T int | float64](value T) {
fmt.Println(value)
}
多个类型参数
可以声明多个类型参数:
// 两个类型参数 K 和 V
func Map[K, V any](key K, value V) {
fmt.Printf("Key: %v, Value: %v\n", key, value)
}
// 调用
Map("name", "张三") // K=string, V=string
Map(1, "one") // K=int, V=string
类型约束
类型约束定义了类型参数允许的类型集合。约束使用接口类型表示:
// any 表示允许任何类型
func Identity[T any](v T) T {
return v
}
// int | float64 表示只允许 int 或 float64
func Add[T int | float64](a, b T) T {
return a + b
}
内置约束
Go 提供了几个常用的内置约束:
any 约束
any 是 interface{} 的别名,表示允许任何类型:
func PrintAny[T any](v T) {
fmt.Println(v)
}
// 可以传入任何类型
PrintAny(42)
PrintAny("hello")
PrintAny([]int{1, 2, 3})
comparable 约束
comparable 约束允许使用 == 和 != 运算符的类型:
// 包含所有可比较的类型(除了 map、slice、function)
func Contains[T comparable](slice []T, target T) bool {
for _, v := range slice {
if v == target {
return true
}
}
return false
}
// 使用
Contains([]int{1, 2, 3}, 2) // true
Contains([]string{"a", "b"}, "c") // false
comparable 包含的类型:
- 所有基本类型(int、float、string、bool 等)
- 指针
- 通道
- 接口
- 数组(元素可比较)
- 结构体(所有字段可比较)
comparable 不包含的类型:
- slice
- map
- function
自定义约束
接口约束
可以定义自己的约束接口:
// 定义数字类型约束
type Number interface {
int | int8 | int16 | int32 | int64 |
float32 | float64
}
func Sum[T Number](numbers []T) T {
var total T
for _, n := range numbers {
total += n
}
return total
}
// 使用
Sum([]int{1, 2, 3}) // 6
Sum([]float64{1.5, 2.5}) // 4.0
联合约束(Union)
使用 | 表示类型联合:
// 定义支持的类型集合
type Signed interface {
~int | ~int8 | ~int16 | ~int32 | ~int64
}
type Unsigned interface {
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64
}
type Integer interface {
Signed | Unsigned
}
底层类型约束(~)
~ 表示底层类型,允许基于某个类型的自定义类型:
type MyInt int
type MyFloat float64
// 没有 ~,只允许 int 和 float64
func Strict[T int | float64](v T) {}
// 有 ~,允许 int 和所有底层类型是 int 的类型
func Flexible[T ~int | ~float64](v T) {}
// 使用
var myInt MyInt = 10
// Strict(myInt) // 编译错误
Flexible(myInt) // 可以
方法约束
约束可以要求类型实现特定方法:
type Stringer interface {
String() string
}
// T 必须实现 String() 方法
func Print[T Stringer](v T) {
fmt.Println(v.String())
}
// 实现 Stringer 接口
type Person struct {
Name string
}
func (p Person) String() string {
return "Person: " + p.Name
}
// 使用
Print(Person{Name: "张三"}) // 输出: Person: 张三
组合约束
约束可以嵌套组合:
// 基础约束
type Ordered interface {
~int | ~float64 | ~string
}
// 组合约束
type OrderedNumber interface {
Ordered
Number // 假设 Number 已定义
}
// 复杂约束组合
type Serializer interface {
MarshalJSON() ([]byte, error)
UnmarshalJSON([]byte) error
}
泛型类型
泛型结构体
// 泛型链表节点
type Node[T any] struct {
Value T
Next *Node[T]
}
// 使用
head := &Node[int]{Value: 1}
head.Next = &Node[int]{Value: 2}
// 泛型栈
type Stack[T any] struct {
items []T
}
func (s *Stack[T]) Push(item T) {
s.items = append(s.items, item)
}
func (s *Stack[T]) Pop() (T, bool) {
if len(s.items) == 0 {
var zero T
return zero, false
}
item := s.items[len(s.items)-1]
s.items = s.items[:len(s.items)-1]
return item, true
}
func (s *Stack[T]) Len() int {
return len(s.items)
}
// 使用
stack := &Stack[string]{}
stack.Push("hello")
stack.Push("world")
item, _ := stack.Pop() // "world"
泛型 Map
type Map[K comparable, V any] struct {
data map[K]V
}
func NewMap[K comparable, V any]() *Map[K, V] {
return &Map[K, V]{
data: make(map[K]V),
}
}
func (m *Map[K, V]) Set(key K, value V) {
m.data[key] = value
}
func (m *Map[K, V]) Get(key K) (V, bool) {
v, ok := m.data[key]
return v, ok
}
func (m *Map[K, V]) Delete(key K) {
delete(m.data, key)
}
func (m *Map[K, V]) Keys() []K {
keys := make([]K, 0, len(m.data))
for k := range m.data {
keys = append(keys, k)
}
return keys
}
// 使用
m := NewMap[string, int]()
m.Set("age", 25)
age, ok := m.Get("age")
泛型切片
type Slice[T any] []T
// 过滤
func (s Slice[T]) Filter(predicate func(T) bool) Slice[T] {
result := make(Slice[T], 0)
for _, v := range s {
if predicate(v) {
result = append(result, v)
}
}
return result
}
// 映射
func MapSlice[T, U any](s []T, f func(T) U) []U {
result := make([]U, len(s))
for i, v := range s {
result[i] = f(v)
}
return result
}
// 使用
nums := Slice[int]{1, 2, 3, 4, 5}
evens := nums.Filter(func(n int) bool {
return n%2 == 0
})
// evens = [2, 4]
strs := MapSlice([]int{1, 2, 3}, func(n int) string {
return fmt.Sprintf("Number: %d", n})
// strs = ["Number: 1", "Number: 2", "Number: 3"]
泛型接口
// 泛型容器接口
type Container[T any] interface {
Add(item T)
Remove() (T, bool)
Len() int
}
// 泛型比较器
type Comparator[T any] interface {
Compare(a, b T) int // 返回 -1, 0, 1
}
// 泛型排序函数
func Sort[T any](slice []T, cmp Comparator[T]) {
sort.Slice(slice, func(i, j int) bool {
return cmp.Compare(slice[i], slice[j]) < 0
})
}
// 实现比较器
type IntComparator struct{}
func (c IntComparator) Compare(a, b int) int {
if a < b {
return -1
} else if a > b {
return 1
}
return 0
}
// 使用
Sort([]int{3, 1, 2}, IntComparator{})
类型推断
Go 编译器可以推断类型参数,简化调用代码:
函数参数推断
func Identity[T any](v T) T {
return v
}
// 显式指定类型参数
result1 := Identity[int](42)
// 编译器推断类型参数
result2 := Identity(42) // 推断为 int
约束推断
type Number interface {
int | float64
}
func Add[T Number](a, b T) T {
return a + b
}
// 编译器从参数推断 T 为 int
Add(1, 2)
// 编译器从参数推断 T 为 float64
Add(1.5, 2.5)
不能推断的情况
func New[T any]() *T {
return new(T)
}
// 必须显式指定类型
p := New[int]() // 无法推断
常用泛型模式
泛型 Option 模式
type Option[T any] struct {
value *T
}
func Some[T any](v T) Option[T] {
return Option[T]{value: &v}
}
func None[T any]() Option[T] {
return Option[T]{}
}
func (o Option[T]) IsSome() bool {
return o.value != nil
}
func (o Option[T]) IsNone() bool {
return o.value == nil
}
func (o Option[T]) Unwrap() T {
if o.value == nil {
panic("unwrap on None")
}
return *o.value
}
func (o Option[T]) UnwrapOr(defaultValue T) T {
if o.value == nil {
return defaultValue
}
return *o.value
}
// 使用
name := Some("张三")
age := None[int]()
fmt.Println(name.IsSome()) // true
fmt.Println(age.IsNone()) // true
fmt.Println(name.Unwrap()) // "张三"
fmt.Println(age.UnwrapOr(0)) // 0
泛型 Result 模式
type Result[T any] struct {
value T
err error
}
func Ok[T any](v T) Result[T] {
return Result[T]{value: v}
}
func Err[T any](err error) Result[T] {
return Result[T]{err: err}
}
func (r Result[T]) IsOk() bool {
return r.err == nil
}
func (r Result[T]) IsErr() bool {
return r.err != nil
}
func (r Result[T]) Unwrap() T {
if r.err != nil {
panic(r.err)
}
return r.value
}
func (r Result[T]) UnwrapOr(defaultValue T) T {
if r.err != nil {
return defaultValue
}
return r.value
}
func (r Result[T]) Error() error {
return r.err
}
// 使用示例
func Divide(a, b float64) Result[float64] {
if b == 0 {
return Err[float64](fmt.Errorf("division by zero"))
}
return Ok(a / b)
}
result := Divide(10, 3)
if result.IsOk() {
fmt.Println(result.Unwrap())
} else {
fmt.Println(result.Error())
}
泛型集合操作
// 去重
func Deduplicate[T comparable](slice []T) []T {
seen := make(map[T]bool)
result := make([]T, 0)
for _, v := range slice {
if !seen[v] {
seen[v] = true
result = append(result, v)
}
}
return result
}
// 分组
func GroupBy[T any, K comparable](slice []T, keyFunc func(T) K) map[K][]T {
groups := make(map[K][]T)
for _, v := range slice {
key := keyFunc(v)
groups[key] = append(groups[key], v)
}
return groups
}
// 查找
func Find[T any](slice []T, predicate func(T) bool) (T, bool) {
for _, v := range slice {
if predicate(v) {
return v, true
}
}
var zero T
return zero, false
}
// 转换
func Transform[T, U any](slice []T, transform func(T) U) []U {
result := make([]U, len(slice))
for i, v := range slice {
result[i] = transform(v)
}
return result
}
// 使用
nums := []int{1, 2, 2, 3, 3, 3}
unique := Deduplicate(nums) // [1, 2, 3]
type Person struct {
Name string
Age int
}
people := []Person{
{"张三", 25},
{"李四", 30},
{"王五", 25},
}
byAge := GroupBy(people, func(p Person) int {
return p.Age
})
// map[25:[张三 王五] 30:[李四]]
泛型最佳实践
1. 保持约束简单
// 好:约束简单明了
func Max[T Ordered](a, b T) T {
if a > b {
return a
}
return b
}
// 不好:过于复杂的约束
func ComplexFunc[T interface {
~int | ~float64
String() string
comparable
}](v T) {}
2. 优先使用标准约束
// 使用内置约束
func Contains[T comparable](slice []T, target T) bool
// 自定义约束仅在必要时
type MyConstraint interface {
~int | ~float64
}
3. 选择合适的类型参数名
// 常用命名约定
// T - 通用类型参数
// K - 键类型
// V - 值类型
// E - 元素类型
func Map[K comparable, V any](m map[K]V) {}
func Slice[E any](s []E) {}
4. 避免过度泛型化
// 不需要泛型的情况
func PrintString(s string) {
fmt.Println(s)
}
// 不要这样
func Print[T string](s T) {
fmt.Println(s)
}
5. 考虑性能影响
// 泛型在编译时会生成具体类型的代码
// 每个类型参数都会生成一份代码
// 如果类型很多,可能增加二进制大小
// 使用接口可能更合适的情况
func Process(data io.Reader) error {
// ...
}
泛型与接口的区别
| 特性 | 泛型 | 接口 |
|---|---|---|
| 类型检查 | 编译时 | 编译时 |
| 性能 | 无运行时开销 | 有虚方法调用开销 |
| 代码生成 | 为每种类型生成代码 | 共享代码 |
| 适用场景 | 操作数据结构 | 定义行为契约 |
// 泛型适合:操作数据、集合
func Map[T, U any](s []T, f func(T) U) []U
// 接口适合:定义行为
type Writer interface {
Write(p []byte) (n int, err error)
}
小结
本章学习了 Go 泛型的核心概念和实践:
- 基本语法:类型参数声明、约束定义
- 内置约束:
any、comparable - 自定义约束:接口约束、联合约束、底层类型约束
- 泛型类型:泛型结构体、泛型容器
- 类型推断:简化调用代码
- 常用模式:Option、Result、集合操作
- 最佳实践:保持简单、选择合适场景
练习
- 实现一个泛型
Stack和Queue数据结构 - 实现泛型函数:
Min、Max、Contains、IndexOf - 实现泛型
Pair[K, V]类型 - 使用泛型重写一个你之前用
interface{}实现的函数 - 实现泛型排序算法