我自定义一个spark sql的函数,中间值产生的是数组,但是设置中间类型为ArrayType会报红
把ArrayType该父类DataType不报红,但是运行报错
难道不可以产生的中间类型为Array吗?求大神解答,感谢!!!
下面是全部代码
class GameDuration(startDayLong: Long, endDayLong: Long) extends UserDefinedAggregateFunction {
//输入的数据类型
override def inputSchema: StructType = StructType(
StructField("eventType", IntegerType) ::
StructField("timestamp", LongType) :: Nil
)
//产生的中间结果类型
override def bufferSchema: StructType = StructType(
StructField("list", ArrayType) :: Nil
)
//最终返回的结果类型
override def dataType: DataType = LongType
override def deterministic: Boolean = true
//指定初始值
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(1) = ArrayBuffer[Long]()
}
//每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的运算)
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val eventType = input.getInt(0)
val timestamp = input.getLong(1)
var tmp = 0L
if (eventType == 3) tmp = timestamp
else tmp = -timestamp
buffer(1) = buffer.getAs[ArrayBuffer[Long]](1) += tmp
}
//全局聚合
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
}
//最终计算
override def evaluate(buffer: Row): Long = {
val arr: ArrayBuffer[Long] = buffer.getAs[ArrayBuffer[Long]](1)
if(arr(0) > 0) arr += -startDayLong
if(arr.last < 0) arr += endDayLong
arr.sum
}
}
还有这是我调用该函数,不知道是否正确啊