[[IT知识]] 深入解析Spark SQL中的UDF与UDAF函数

[复制链接]
查看: 51|回复: 0
发表于 2025-1-25 09:15:02 | 显示全部楼层 | 阅读模式
易博V9下载

深入解析Spark SQL中的UDF与UDAF函数

前言

UDF、UDAF、UDTF都是用户自定义函数,用户可以通过

  1. spark.udf
复制代码
功能添加自定义函数,实现自定义功能。

UDF:用户自定义函数(User Defined Function),一行输入一行输出。

UDAF:用户自定义聚合函数(User Defined Aggregate Function),多行输入一行输出。

UDTF:用户自定义表函数(User Defined Table Generating Function),一行输入多行输出。

聚合函数和普通函数的区别:普通函数是接受一行输入产生一个输出,聚合函数是接受一组(一般是多行)输入然后产生一个输出,即将一组的值想办法聚合一下。

本篇将介绍UDF和UDAF函数。

一、概念 UDF

UDF(User-Defined-Function),也就是最基本的函数,它提供了SQL中对字段转换的功能,不涉及聚合操作。

适用场景:UDF使用频率极高,对于单条记录进行比较复杂的操作,使用内置函数无法完成或者比较复杂的情况都比较适合使用UDF。

UDAF

UDAF(User-Defined-Aggregate-Function)函数是用户自定义的聚合函数,为Spark SQL 提供对数据集的聚合功能。

类似于max()、min()、count()等功能,只不过自定义的功能是根据具体的业务功能来确定的。

因为DataFrame是弱类型的,DataSet是强类型,所以自定义的 UDAF也提供了两种实现,一个是弱类型的一个是强类型的(不常用)。

误区

我们可能下意识的认为UDAF是需要和group by一起使用的,实际上UDAF可以跟group by一起使用,也可以不跟group by一起使用,这个其实比较好理解,联想到mysql中的max、min等函数,可以:

  1. sql
  2. select max(age) from person group by address;
复制代码

表示根据address字段分组,然后求每个分组的最大值,这时候的分组有很多个,使用这个函数对每个分组进行处理,也可以:

  1. sql
  2. select max(age) from person;
复制代码

这种情况可以将整张表看做是一个分组,然后在这个分组(实际上就是一整张表)中求最大值。所以聚合函数实际上是对分组做处理,而不关心分组中记录的具体数量。

二、具体用法 2.1 UDF用法

具体步骤:

  • 实现UDF,可以是case class,可以是匿名类
  • 注册到spark,将类绑定到一个name,后续会使用这个name来调用函数
  • 在sql语句中调用注册的name调用UDF

代码示例:

  1. scala
  2. import org.apache.spark.rdd.RDD
  3. import org.apache.spark.sql.SparkSession
  4. /**
  5. * @author lilinchao
  6. * @date 2021/7/15
  7. * @description 1.0
  8. **/
  9. object SparkSQL_UDF {
  10. def main(args: Array[String]): Unit = {
  11. val spark = SparkSession.builder().master("local[*]").appName("udf").getOrCreate()
  12. //后面要用到toDF,必须导入这个隐式转换
  13. import spark.implicits._
  14. //引入数据源
  15. val rdd: RDD[(String, String)] = spark.sparkContext.parallelize(Seq(("010","zhagnsan"),("0020","王五"),("00345","赵六")))
  16. //将集合转成dataFrame,并创建临时表
  17. rdd.toDF("id","name").createOrReplaceTempView("person")
  18. //注册自定义udf函数
  19. spark.udf.register("fillZero",fillZero _)
  20. //自定义匿名函数,统计字符串长度
  21. spark.udf.register("strLen",(str: String) => str.length())
  22. //没有加自定义函数
  23. spark.sql("select id,name from person").show()
  24. //加了自定义udf函数
  25. spark.sql("select fillZero(id),name,strLen(name) from person").show()
  26. spark.close()
  27. }
  28. /**
  29. * 补全Id
  30. */
  31. def fillZero(id:String):String = {
  32. "0"*(8-id.length) id
  33. }
  34. }
复制代码

直接对列使用UDF

在sql语句中使用比较麻烦,还要进行注册,可以定义一个UDF然后将它直接应用到某个列上:

  1. scala
  2. import org.apache.spark.sql.{SparkSession, functions}
  3. /**
  4. * @author lilinchao
  5. * @date 2021/7/15
  6. * @description 1.0
  7. **/
  8. object Spark01_SparkSQL_UDF2 {
  9. def main(args: Array[String]): Unit = {
  10. val spark = SparkSession.builder().master("local[*]").appName("udf").getOrCreate()
  11. import spark.implicits._
  12. val ds = Seq((1, "zhangsan"), (2, "lisi")).toDF("id", "name")
  13. //自定义匿名函数,小写转大写
  14. val toUpperCase = functions.udf((s: String) => s.toUpperCase)
  15. ds.withColumn("name", toUpperCase('name)).show()
  16. spark.close()
  17. }
  18. }
复制代码
2.2 UDAF用法

数据准备:

user.json文件

  1. json
  2. {"id": 1001, "name": "王小帅", "sex": "man", "age": 22}
  3. {"id": 1002, "name": "岳小林", "sex": "man", "age": 16}
  4. {"id": 1003, "name": "邱小峰", "sex": "man", "age": 18}
  5. {"id": 1004, "name": "刘小明", "sex": "woman", "age": 17}
  6. {"id": 1005, "name": "张小飞", "sex": "woman", "age": 19}
  7. {"id": 1006, "name": "李小刀", "sex": "woman", "age": 20}
复制代码
1. 继承
  1. UserDefinedAggregateFunction
复制代码

具体步骤:

  • 自定义类继承
    1. UserDefinedAggregateFunction
    复制代码
    ,对每个阶段方法做实现
  • 在spark中注册UDAF,为其绑定一个名称
  • 在sql语句中使用上面绑定的名字调用

下面写一个计算平均值的UDAF例子

首先定义一个类继承UserDefinedAggregateFunction:

  1. scala
  2. import org.apache.spark.sql.Row
  3. import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
  4. import org.apache.spark.sql.types._
  5. import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
  6. /**
  7. * @author lilinchao
  8. * @date 2021/7/15
  9. * @description 1.0
  10. **/
  11. object AverageUserDefinedAggregateFunction extends UserDefinedAggregateFunction{
  12. // 聚合函数的输入数据结构
  13. override def inputSchema: StructType = StructType(StructField("input", LongType) :: Nil)
  14. // 缓存区数据结构
  15. override def bufferSchema: StructType = StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  16. // 聚合函数返回值数据结构
  17. override def dataType: DataType = DoubleType
  18. // 聚合函数是否是幂等的,即相同输入是否总是能得到相同输出
  19. override def deterministic: Boolean = true
  20. // 初始化缓冲区
  21. override def initialize(buffer: MutableAggregationBuffer): Unit = {
  22. buffer(0) = 0L
  23. buffer(1) = 0L
  24. }
  25. // 给聚合函数传入一条新数据进行处理
  26. override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  27. if (input.isNullAt(0)) return
  28. buffer(0) = buffer.getLong(0) input.getLong(0)
  29. buffer(1) = buffer.getLong(1) 1
  30. }
  31. // 合并聚合函数缓冲区
  32. override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
  33. buffer1(0) = buffer1.getLong(0) buffer2.getLong(0)
  34. buffer1(1) = buffer1.getLong(1) buffer2.getLong(1)
  35. }
  36. // 计算最终结果
  37. override def evaluate(buffer: Row): Any = buffer.getLong(0).toDouble / buffer.getLong(1)
  38. }
复制代码

在主函数中进行注册并完成调用

  1. scala
  2. import org.apache.spark.sql.SparkSession
  3. /**
  4. * @author lilinchao
  5. * @date 2021/7/15
  6. * @description 1.0
  7. **/
  8. object SparkSql_UDAFDemo {
  9. def main(args: Array[String]): Unit = {
  10. val spark = SparkSession.builder().master("local[*]").appName("SparkUDAF").getOrCreate()
  11. spark.read.json("input/user.json").createOrReplaceTempView("user")
  12. spark.udf.register("u_avg", AverageUserDefinedAggregateFunction)
  13. // 将整张表看做是一个分组对求所有人的平均年龄
  14. spark.sql("select count(1) as count, u_avg(age) as avg_age from user").show()
  15. // 按照性别分组求平均年龄
  16. spark.sql("select sex, count(1) as count, u_avg(age) as avg_age from user group by sex").show()
  17. }
  18. }
复制代码
2. 继承Aggregator

继承Aggregator这个类,优点是可以带类型

  1. scala
  2. import org.apache.spark.sql.expressions.Aggregator
  3. import org.apache.spark.sql.{Encoder, Encoders}
  4. /**
  5. * @author lilinchao
  6. * @date 2021/7/15
  7. * @description 计算平均值
  8. **/
  9. object AverageAggregator extends Aggregator[User, Average, Double]{
  10. // 初始化buffer
  11. override def zero: Average = Average(0L, 0L)
  12. // 处理一条新的记录
  13. override def reduce(b: Average, a: User): Average = {
  14. b.sum = a.age
  15. b.count = 1L
  16. b
  17. }
  18. // 合并聚合buffer
  19. override def merge(b1: Average, b2: Average): Average = {
  20. b1.sum = b2.sum
  21. b1.count = b2.count
  22. b1
  23. }
  24. // 减少中间数据传输
  25. override def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
  26. override def bufferEncoder: Encoder[Average] = Encoders.product
  27. // 最终输出结果的类型
  28. override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  29. }
  30. /**
  31. * 计算平均值过程中使用的Buffer
  32. *
  33. * @param sum
  34. * @param count
  35. */
  36. case class Average(var sum: Long, var count: Long) {
  37. }
  38. case class User(id: Long, name: String, sex: String, age: Long) {
  39. }
复制代码

主函数调用

  1. scala
  2. import org.apache.spark.sql.SparkSession
  3. /**
  4. * @author lilinchao
  5. * @date 2021/7/15
  6. * @description 1.0
  7. **/
  8. object SparkSql_UDAFDemo02 {
  9. def main(args: Array[String]): Unit = {
  10. val spark = SparkSession.builder().master("local[*]").appName("SparkUDAF").getOrCreate()
  11. import spark.implicits._
  12. val user = spark.read.json("input/user.json").as[User]
  13. user.select(AverageAggregator.toColumn.name("avg")).show()
  14. }
  15. }
复制代码
易博软件介绍
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

1、请认真发帖,禁止回复纯表情,纯数字等无意义的内容!帖子内容不要太简单!
2、提倡文明上网,净化网络环境!抵制低俗不良违法有害信息。
3、如果你对主帖作者的帖子不屑一顾的话,请勿回帖。谢谢合作!
3、问答求助区发帖求助后,如有其他用户热心帮您解决问题后,请自觉点击设为最佳答案按钮。

 
 
QQ在线客服
QQ技术支持
工作时间:
8:00-18:00
软著登字:
1361266号
官方微信扫一扫
weixin

QQ|小黑屋|Archiver|慈众营销 ( 粤ICP备15049986号 )|网站地图

自动发帖软件 | 自动发帖器 | 营销推广软件 | 网络营销工具 | 网络营销软件 | 网站推广工具 | 网络推广软件 | 网络推广工具 | 网页推广软件 | 信息发布软件 | 网站推广工具 | 网页推广软件

Powered by Discuz! X3.4   © 2012-2020 Comsenz Inc.  慈众科技 - Collect from 深圳吉宝泰佛文化有限公司 公司地址:罗湖区黄贝街道深南东路集浩大厦A1403

返回顶部 返回列表